From 91aba7baac3d2a079c0b13db25588842260c98cc Mon Sep 17 00:00:00 2001 From: YangQun1 Date: Thu, 25 Jul 2024 13:04:23 +0000 Subject: [PATCH] Fix py codegen to delete values that don't have any users (#131028) Fixes #131025 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131028 Approved by: https://github.com/ezyang --- test/dynamo/test_autograd_function.py | 6 +- test/dynamo/test_comptime.py | 4 +- test/dynamo/test_ctx_manager.py | 34 +- test/dynamo/test_export.py | 24 +- test/dynamo/test_higher_order_ops.py | 764 ++++++++++----------- test/dynamo/test_input_attr_tracking.py | 6 +- test/dynamo/test_misc.py | 4 +- test/dynamo/test_repros.py | 2 +- test/export/test_export.py | 22 +- test/export/test_passes.py | 6 +- test/export/test_torchbind.py | 22 +- test/functorch/test_aotdispatch.py | 64 +- test/functorch/test_control_flow.py | 38 +- test/functorch/test_eager_transforms.py | 22 +- test/higher_order_ops/test_with_effects.py | 2 +- test/inductor/test_flex_attention.py | 20 +- test/test_functionalization.py | 286 ++++---- test/test_fx.py | 25 + test/test_fx_reinplace_pass.py | 38 +- test/test_proxy_tensor.py | 20 +- torch/fx/graph.py | 7 + 21 files changed, 724 insertions(+), 692 deletions(-) diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 333438757aee..acc8999f26ab 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -519,7 +519,7 @@ class GraphModule(torch.nn.Module): l_weird_b = L_weird_b l_weird_c = L_weird_c - function_ctx = torch.autograd.function.FunctionCtx() + function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None fwd_body_0 = self.fwd_body_0 bwd_body_0 = self.bwd_body_0 autograd_function_apply: "f32[]" = torch._functorch.autograd_function.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True]); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None @@ -534,13 +534,13 @@ class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module): def forward(self, ctx, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"): - _set_grad_enabled = torch._C._set_grad_enabled(False) + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None mul: "f32[]" = grad * l_weird_b; l_weird_b = None mul_1: "f32[]" = mul * l_weird_c; mul = l_weird_c = None mul_2: "f32[]" = grad * 2; grad = None - _set_grad_enabled_1 = torch._C._set_grad_enabled(True) + _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None return (mul_1, mul_2) """, ) diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py index a14c889a3bce..48dc3d7d59ef 100644 --- a/test/dynamo/test_comptime.py +++ b/test/dynamo/test_comptime.py @@ -95,7 +95,7 @@ s0""", """\ def forward(self, L_x_ : torch.Tensor): l_x_ = L_x_ - y = l_x_ * 2; l_x_ = None""", + y = l_x_ * 2; l_x_ = y = None""", ) def test_print_disas(self): @@ -391,7 +391,7 @@ y = TensorVariable() def forward(self, L_x_ : torch.Tensor): l_x_ = L_x_ y = l_x_ * 2; l_x_ = None - add = y + 4; y = None""", + add = y + 4; y = add = None""", ) diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 47f8e8eeb863..a326791635f3 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -1057,7 +1057,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): """\ class GraphModule(torch.nn.Module): def forward(self): - _saved_tensors_hooks_disable = 
torch._C._autograd._saved_tensors_hooks_disable('This is not supported') + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None x: "f32[1]" = torch.ones(1) @@ -1065,9 +1065,9 @@ class GraphModule(torch.nn.Module): add: "f32[1]" = x + y; x = y = None - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (add,) -""", +""", # NOQA: B950 ) def test_disable_saved_tensors_hooks_prev_disabled(self): @@ -1097,7 +1097,7 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self): - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported') + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None x: "f32[1]" = torch.ones(1) @@ -1105,9 +1105,9 @@ class GraphModule(torch.nn.Module): add: "f32[1]" = x + y; x = y = None - _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message') + _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message'); _saved_tensors_hooks_disable_1 = None return (add,) -""", +""", # NOQA: B950 ) def test_disable_saved_tensors_hooks_prev_disabled_nested(self): @@ -1143,23 +1143,23 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self): - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported') + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None x: "f32[1]" = torch.ones(1) y: "f32[1]" = torch.zeros(1) - _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported inner') + _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported inner'); _saved_tensors_hooks_disable_1 = None add: "f32[1]" = x + y; y = None - _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported') + _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable_2 = None add_1: "f32[1]" = add + x; add = x = None - _saved_tensors_hooks_disable_3 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message') + _saved_tensors_hooks_disable_3 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message'); _saved_tensors_hooks_disable_3 = None return (add_1,) -""", +""", # NOQA: B950 ) def test_disable_saved_tensors_hooks_graph_break(self): @@ -1186,13 +1186,13 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[]"): l_x_ = L_x_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported') + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None y: "f32[]" = l_x_ + 1; l_x_ = None - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (y,) -""", +""", # NOQA: B950 ) graph = eager.graphs[1] @@ -1204,13 
+1204,13 @@ class GraphModule(torch.nn.Module): def forward(self, L_y_: "f32[]"): l_y_ = L_y_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported') + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None mul: "f32[]" = l_y_ * 2; l_y_ = None - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (mul,) -""", +""", # NOQA: B950 ) def test_context_wrapping_grad_mode_decorator(self): diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index d882224797ac..192a43846c1e 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -4504,8 +4504,8 @@ def forward(self, x): l_args_0_ = arg0 _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) add = l_args_0_ + 1; l_args_0_ = None - _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = None - return pytree.tree_unflatten([add], self._out_spec)""", + _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None + return pytree.tree_unflatten([add], self._out_spec)""", # NOQA: B950 ) self.assertEqual(out.requires_grad, False) with self.assertRaisesRegex( @@ -4527,8 +4527,8 @@ def forward(self, x): l_args_0_ = arg0 _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(False) add = l_args_0_ + 1; l_args_0_ = None - _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = None - return pytree.tree_unflatten([add], self._out_spec)""", + _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None + return pytree.tree_unflatten([add], self._out_spec)""", # NOQA: B950 ) inp = torch.randn(2, 2) @@ -4549,8 +4549,8 @@ def forward(self, x): l_x_ = arg0 _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) add = l_x_ + 1; l_x_ = None - _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = None - return pytree.tree_unflatten([add], self._out_spec)""", + _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None + return pytree.tree_unflatten([add], self._out_spec)""", # NOQA: B950 ) inp = torch.randn(2, 2, requires_grad=True) out = gm(inp) @@ -4583,10 +4583,10 @@ def forward(self, x, b, y): l_x_ = arg0 l_b_ = arg1 l_y_ = arg2 - _set_grad_enabled = torch._C._set_grad_enabled(False) + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None x = l_x_.clone(); l_x_ = None - x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = None - _set_grad_enabled_1 = torch._C._set_grad_enabled(True) + x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = setitem = None + _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None return pytree.tree_unflatten([x], self._out_spec)""", ) @@ -4601,9 +4601,9 @@ def forward(self, x, b, y): l_y_ = arg2 _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) x = l_x_.clone(); l_x_ = None - x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = None - _exit_inference_mode = 
torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = None - return pytree.tree_unflatten([x], self._out_spec)""", + x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = setitem = None + _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None + return pytree.tree_unflatten([x], self._out_spec)""", # NOQA: B950 ) with self.assertRaisesRegex( diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 8637fe26d9e6..f8270684784a 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -1192,7 +1192,7 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor): body_graph, """\ def forward(self, child, l_y_): - child_1 = child[0] + child_1 = child[0]; child_1 = None map_body_0 = self.map_body_0 map_impl = torch.ops.higher_order.map_impl(map_body_0, [child], [l_y_]); map_body_0 = child = l_y_ = None getitem_1 = map_impl[0]; map_impl = None @@ -2826,78 +2826,78 @@ class GraphModule(torch.nn.Module): cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None neg: "i64[0]" = getitem.neg(); getitem = None - unbind = neg.unbind(); neg = None + unbind = neg.unbind(); neg = unbind = None chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12) diagonal: "f32[12]" = chunk.diagonal(0) - fill_: "f32[12]" = diagonal.fill_(1); diagonal = None + fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None child_2 = torch._make_dual(l_x_, child_1, level = 0); child_1 = None - _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None + _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None diff_primals = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None o = torch.sin(diff_primals) results = torch._C._functorch._unwrap_for_grad(o, 3) - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None tensor_1 = torch.tensor((12,)) cumsum_1 = tensor_1.cumsum(dim = 0); tensor_1 = None getitem_1 = cumsum_1[slice(None, -1, None)]; cumsum_1 = None neg_1 = getitem_1.neg(); getitem_1 = None - unbind_1 = neg_1.unbind(); neg_1 = None + unbind_1 = neg_1.unbind(); neg_1 = unbind_1 = None chunk_1 = results.new_zeros(12, 12); results = None diagonal_1 = chunk_1.diagonal(0) - fill__1 = diagonal_1.fill_(1); diagonal_1 = None + fill__1 = diagonal_1.fill_(1); diagonal_1 = fill__1 = None basis = chunk_1.view(12, 4, 3); chunk_1 = None - lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None - _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error') + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None _add_batch_dim_1 = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None - _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1) + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1); _vjp_treespec_compare = None _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim_1 = None batched_outputs = _autograd_grad[0]; _autograd_grad = None chunked_result = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None split = chunked_result.split((12,), dim = 
0); chunked_result = None split_1 = split[0]; split = None @@ -2908,17 +2908,17 @@ class GraphModule(torch.nn.Module): primal = _unpack_dual[0] dual = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None + primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting() + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None results_1: "f32[12, 4, 3, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None - _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None movedim: "f32[4, 3, 4, 3, 12]" = results_1.movedim(0, -1); results_1 = None split_2 = movedim.split((12,), dim = -1); movedim = None @@ -2958,80 +2958,80 @@ class GraphModule(torch.nn.Module): cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None neg: "i64[0]" = getitem.neg(); getitem = None - unbind = neg.unbind(); neg = None + unbind = neg.unbind(); neg = unbind = None chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12) diagonal: "f32[12]" = chunk.diagonal(0) - fill_: "f32[12]" = diagonal.fill_(1); diagonal = None + fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None child_3 = torch._make_dual(l_y_, child_1, level = 0); child_1 = None child_2 = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None - _wrap_for_grad_1 = 
torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = None + _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None _wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None child_4 = torch._C._functorch._wrap_for_grad(child_3, 3); child_3 = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None o = _wrap_for_grad_2.sin(); _wrap_for_grad_2 = None results = torch._C._functorch._unwrap_for_grad(o, 3) - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None tensor_1 = torch.tensor((12,)) cumsum_1 = tensor_1.cumsum(dim = 0); tensor_1 = None getitem_1 = cumsum_1[slice(None, -1, None)]; cumsum_1 = None neg_1 = getitem_1.neg(); getitem_1 = None - unbind_1 = neg_1.unbind(); neg_1 = None + unbind_1 = neg_1.unbind(); neg_1 = unbind_1 = None chunk_1 = results.new_zeros(12, 12); results = None diagonal_1 = chunk_1.diagonal(0) - fill__1 = diagonal_1.fill_(1); diagonal_1 = None + fill__1 = diagonal_1.fill_(1); diagonal_1 = fill__1 = None basis = chunk_1.view(12, 4, 3); chunk_1 = None - lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None - _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error') + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None _add_batch_dim_1 = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None - _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1) + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1); _vjp_treespec_compare = None _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = child_4 = _add_batch_dim_1 = 
None child_5 = _autograd_grad[0]; _autograd_grad = None child_6 = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None split = child_6.split((12,), dim = 0); child_6 = None split_1 = split[0]; split = None @@ -3043,17 +3043,17 @@ class GraphModule(torch.nn.Module): tangent = torch.zeros_like(primal) - child_8: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None + child_8: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_8 = None child_9: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting() + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None child_10: "f32[12, 4, 3, 3, 4]" = torch._C._functorch._remove_batch_dim(child_9, 1, 12, 0); child_9 = None - _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None movedim: "f32[4, 3, 3, 4, 12]" = child_10.movedim(0, -1); child_10 = None split_2 = movedim.split((12,), dim = -1); movedim = None @@ -3113,51 +3113,51 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[4, 3]"): l_x_ = L_x_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None diff_primals = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None o = torch.sin(diff_primals) results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 1) - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None tensor: "i64[1]" = torch.tensor((12,)) cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None neg: "i64[0]" = getitem.neg(); getitem = None - unbind = neg.unbind(); neg = None + unbind = neg.unbind(); neg = unbind = None chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None diagonal: "f32[12]" = chunk.diagonal(0) - fill_: "f32[12]" = diagonal.fill_(1); diagonal = None + fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None basis: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None _add_batch_dim = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None - _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None batched_outputs = _autograd_grad[0]; _autograd_grad = None chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None split = chunked_result.split((12,), dim = 0); chunked_result = None split_1: "f32[12, 4, 3]" = split[0]; split = None @@ -3192,52 +3192,52 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_y_ = L_y_ - _saved_tensors_hooks_disable = 
torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None - _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None + _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = _wrap_for_grad = None diff_primals = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None o = diff_primals.sin() results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(o, 1) - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None tensor: "i64[1]" = torch.tensor((12,)) cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None neg: "i64[0]" = getitem.neg(); getitem = None - unbind = neg.unbind(); neg = None + unbind = neg.unbind(); neg = unbind = None chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None diagonal: "f32[12]" = chunk.diagonal(0) - fill_: "f32[12]" = diagonal.fill_(1); diagonal = None + fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None _add_batch_dim = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None - _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None batched_outputs = _autograd_grad[0]; _autograd_grad = None 
chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None split = chunked_result.split((12,), dim = 0); chunked_result = None split_1: "f32[12, 3, 4]" = split[0]; split = None @@ -3272,17 +3272,17 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_y_ = L_y_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None aux = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None diff_primals = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None o = diff_primals.sin() @@ -3290,36 +3290,36 @@ class GraphModule(torch.nn.Module): results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(o, 1) - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None tensor: "i64[1]" = torch.tensor((12,)) cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None neg: "i64[0]" = getitem.neg(); getitem = None - unbind = neg.unbind(); neg = None + unbind = neg.unbind(); neg = unbind = None chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None diagonal: "f32[12]" = chunk.diagonal(0) - fill_: "f32[12]" = diagonal.fill_(1); diagonal = None + fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None _add_batch_dim = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = 
None - _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None batched_outputs = _autograd_grad[0]; _autograd_grad = None chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None split = chunked_result.split((12,), dim = 0); chunked_result = None split_1: "f32[12, 3, 4]" = split[0]; split = None @@ -3381,24 +3381,24 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[5]"): l_x_ = L_x_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - child_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child) + child_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None sin = child.sin(); child = None o = sin.sum(); sin = None results: "f32[]" = torch._C._functorch._unwrap_for_grad(o, 1); o = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (results,) """, ) @@ -3429,16 +3429,16 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_v_ = L_v_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None child_3 = torch._functorch.eager_transforms._set_tensor_requires_grad(child) - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None child_1 = child.sin() child_2 = child.cos(); child = None @@ -3446,10 +3446,10 @@ class GraphModule(torch.nn.Module): _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1) - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None - _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare((child_1, child_2), (l_v_, l_v_)) + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare((child_1, child_2), (l_v_, l_v_)); _vjp_treespec_compare = None _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, l_v_], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = None getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None @@ -3483,16 +3483,16 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_v_ = L_v_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None child_3 = torch._functorch.eager_transforms._set_tensor_requires_grad(child) - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None child_1 = child.sin() child_2 = child.cos(); child = None @@ -3500,12 +3500,12 @@ class GraphModule(torch.nn.Module): _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1) - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None child_4: "f32[5]" = l_v_.sin() - _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare({'first': child_1, 'second': child_2}, {'first': l_v_, 'second': child_4}) + _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare({'first': child_1, 'second': child_2}, {'first': l_v_, 'second': child_4}); _vjp_treespec_compare = None _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, child_4], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = child_4 = None getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None @@ -3539,26 +3539,26 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[5]"): l_x_ = L_x_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - child_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child) + child_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None sin = child.sin() o = sin.sum(); sin = None - aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1); child = None + aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1); child = aux = None results: "f32[]" = torch._C._functorch._unwrap_for_grad(o, 1); o = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (results,) """, ) @@ -3612,16 +3612,16 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3, 3]"): l_x_ = L_x_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None sin = diff_args.sin() output = sin.sum(); sin = None @@ -3631,10 +3631,10 @@ class GraphModule(torch.nn.Module): grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (grad_input_1,) """, ) @@ -3679,16 +3679,16 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3, 3]"): l_x_ = L_x_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None sin = diff_args.sin() add = sin + 3; sin = None @@ -3699,10 +3699,10 @@ class GraphModule(torch.nn.Module): grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (grad_input_1,) """, ) @@ -3736,16 +3736,16 @@ class GraphModule(torch.nn.Module): y: "f32[3]" = torch.randn(3) - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None sin = diff_args.sin() add = sin + y; sin = None @@ -3756,10 +3756,10 @@ class GraphModule(torch.nn.Module): grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (y, grad_input_1) """, ) @@ -3793,16 +3793,16 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3, 3]"): l_x_ = L_x_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None sin = diff_args.sin() add = sin + 3.14; sin = None @@ -3813,10 +3813,10 @@ class GraphModule(torch.nn.Module): grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (grad_input_1,) """, ) @@ -3847,16 +3847,16 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3, 3]"): l_x_ = L_x_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None sin = diff_args.sin() add = sin + 3.14; sin = None @@ -3868,12 +3868,12 @@ class GraphModule(torch.nn.Module): grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (grad_input_1, aux_1) """, ) @@ -3904,17 +3904,17 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_y_ = L_y_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None sin = diff_args.sin() add = sin + _wrap_for_grad_1; sin = _wrap_for_grad_1 = None @@ -3926,12 +3926,12 @@ class GraphModule(torch.nn.Module): grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (grad_input_1, aux_1) """, ) @@ -3973,22 +3973,22 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_y_ = L_y_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None child_1 = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) - set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None + set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None - _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1) + _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None - set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None sin = child.sin() add = sin + child_1; sin = None @@ -4002,12 +4002,12 @@ class GraphModule(torch.nn.Module): _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (_unwrap_for_grad, _unwrap_for_grad_1, aux_1) """, ) @@ -4019,22 +4019,22 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_y_ = L_y_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None child_1 = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) - set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None + set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None - _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1) + _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None - set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None sin = child.sin() add = sin + child_1; sin = None @@ -4048,12 +4048,12 @@ class GraphModule(torch.nn.Module): _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (_unwrap_for_grad, _unwrap_for_grad_1, aux_1) """, ) @@ -4081,26 +4081,26 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[]"): l_x_ = L_x_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) - _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting_1 = torch._C._functorch._grad_increment_nesting() + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None + _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable_1 = None + _grad_increment_nesting_1 = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting_1 = None diff_args_1 = torch._C._functorch._wrap_for_grad(diff_args, 2) - set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None - _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args_1) + _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args_1); _set_tensor_requires_grad_1 = None - set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None sin = diff_args_1.sin() output = sin.sum(); sin = None @@ -4110,20 +4110,20 @@ class GraphModule(torch.nn.Module): grad_input_1 = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None - output_1 = torch._C._functorch._unwrap_for_grad(output, 2); output = None + output_1 = torch._C._functorch._unwrap_for_grad(output, 2); output = output_1 = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable_2 = None _autograd_grad_1 = torch._functorch.eager_transforms._autograd_grad((grad_input_1,), [diff_args], create_graph = True); diff_args = None grad_input_2 = _autograd_grad_1[0]; _autograd_grad_1 = None grad_input_3: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None - output_2: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_1, 1); grad_input_1 = None + output_2: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_1, 1); grad_input_1 = output_2 = None - _grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting_1 = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (grad_input_3,) """, ) @@ -4206,16 +4206,16 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3, 3]"): l_x_ = L_x_ - _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.") - _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None - set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args) + _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None - set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False) + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None sin = diff_args.sin() sum_1 = sin.sum(); sin = None @@ -4226,10 +4226,10 @@ class GraphModule(torch.nn.Module): grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None - _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() - _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (grad_input_1,) """, ) @@ -4299,32 +4299,32 @@ class GraphModule(torch.nn.Module): cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None neg: "i64[0]" = getitem.neg(); getitem = None - 
unbind = neg.unbind(); neg = None + unbind = neg.unbind(); neg = unbind = None chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12) diagonal: "f32[12]" = chunk.diagonal(0) - fill_: "f32[12]" = diagonal.fill_(1); diagonal = None + fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None _make_dual = torch._make_dual(l_x_, child_1, level = 0); child_1 = None - _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None + _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None result_duals = torch.sin(_make_dual); _make_dual = None @@ -4332,17 +4332,17 @@ class GraphModule(torch.nn.Module): primal = _unpack_dual[0] dual = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None + primals_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting() + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None results: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None movedim: "f32[4, 3, 12]" = results.movedim(0, -1); results = None split = movedim.split((12,), dim = -1); movedim = None @@ -4382,33 +4382,33 @@ class GraphModule(torch.nn.Module): cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None neg: "i64[0]" = getitem.neg(); 
getitem = None - unbind = neg.unbind(); neg = None + unbind = neg.unbind(); neg = unbind = None chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12) diagonal: "f32[12]" = chunk.diagonal(0) - fill_: "f32[12]" = diagonal.fill_(1); diagonal = None + fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None _make_dual = torch._make_dual(l_y_, child_1, level = 0); child_1 = None - _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None - _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = None + _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None + _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None result_duals = _make_dual.sin(); _make_dual = None @@ -4416,17 +4416,17 @@ class GraphModule(torch.nn.Module): primal = _unpack_dual[0] dual = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None + primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting() + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None movedim: "f32[3, 4, 12]" = results.movedim(0, -1); results = None split = movedim.split((12,), dim = -1); movedim = None @@ -4466,33 +4466,33 @@ 
class GraphModule(torch.nn.Module): cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None neg: "i64[0]" = getitem.neg(); getitem = None - unbind = neg.unbind(); neg = None + unbind = neg.unbind(); neg = unbind = None chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12) diagonal: "f32[12]" = chunk.diagonal(0) - fill_: "f32[12]" = diagonal.fill_(1); diagonal = None + fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None _make_dual = torch._make_dual(l_y_, child_1, level = 0); child_1 = None aux = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None - _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = None + _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None result_duals = _make_dual.sin(); _make_dual = None @@ -4502,18 +4502,18 @@ class GraphModule(torch.nn.Module): primal = _unpack_dual[0] dual = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None + primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting() + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None aux_2: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(aux_1, 1, 12, 0); aux_1 = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); 
_vmap_decrement_nesting = None aux_3: "f32[4, 3]" = aux_2[0]; aux_2 = None @@ -4555,32 +4555,32 @@ class GraphModule(torch.nn.Module): cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None neg: "i64[0]" = getitem.neg(); getitem = None - unbind = neg.unbind(); neg = None + unbind = neg.unbind(); neg = unbind = None chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12) diagonal: "f32[12]" = chunk.diagonal(0) - fill_: "f32[12]" = diagonal.fill_(1); diagonal = None + fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'same') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'same'); _vmap_increment_nesting = None child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None child_3 = torch._make_dual(l_x_, child_1, level = 0); child_1 = None - _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None + _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = None child_2 = _wrap_for_grad_1.sin(); _wrap_for_grad_1 = None @@ -4594,20 +4594,20 @@ class GraphModule(torch.nn.Module): primal_1 = _unpack_dual_1[0] dual = _unpack_dual_1[1]; _unpack_dual_1 = None - child_4: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None - child_5: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 2); primal_1 = None + child_4: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_4 = None + child_5: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 2); primal_1 = child_5 = None child_6: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None child_7 = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting() + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None child_8: "f32[12, 3, 4]" = 
torch._C._functorch._remove_batch_dim(child_6, 1, 12, 0); child_6 = None child_9: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(child_7, 1, 12, 0); child_7 = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None movedim: "f32[3, 4, 12]" = child_8.movedim(0, -1); child_8 = None split = movedim.split((12,), dim = -1); movedim = None @@ -4676,13 +4676,13 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_v_ = L_v_ - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None _make_dual = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None @@ -4697,9 +4697,9 @@ class GraphModule(torch.nn.Module): tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting() + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None return (primals_out_unflatten, tangents_out_unflatten) """, ) @@ -4730,13 +4730,13 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_v_ = L_v_ - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None aux = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None @@ -4753,9 +4753,9 @@ class GraphModule(torch.nn.Module): tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) - _jvp_decrement_nesting = 
torch._C._functorch._jvp_decrement_nesting() + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None return (primals_out_unflatten, tangents_out_unflatten, aux_1) """, ) @@ -4788,17 +4788,17 @@ class GraphModule(torch.nn.Module): l_y_ = L_y_ l_v_ = L_v_ - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_, l_y_), (l_v_, l_v_)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_, l_y_), (l_v_, l_v_)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None aux = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = None - _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None _make_dual_1 = torch._make_dual(l_y_, l_v_, level = 0); l_y_ = l_v_ = None @@ -4817,9 +4817,9 @@ class GraphModule(torch.nn.Module): tangents_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting() + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None return (primals_out_unflatten, tangents_out_unflatten, aux_1) """, ) @@ -4851,15 +4851,15 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_v_ = L_v_ - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False) + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled = None - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); 
_maybe_load_decompositions = None _make_dual = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None @@ -4874,10 +4874,10 @@ class GraphModule(torch.nn.Module): tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False) - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting() - _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True) + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_2 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None return (primals_out_unflatten, tangents_out_unflatten) """, ) @@ -4920,17 +4920,17 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_v_ = L_v_ - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False) - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) - _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False) + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None + _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_2 = None - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None _make_dual = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None @@ -4945,12 +4945,12 @@ class GraphModule(torch.nn.Module): tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_4 = torch._C._set_fwd_grad_enabled(False) - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting() - _set_fwd_grad_enabled_5 = torch._C._set_fwd_grad_enabled(True) - _set_fwd_grad_enabled_6 = torch._C._set_fwd_grad_enabled(False) - _set_fwd_grad_enabled_7 = torch._C._set_fwd_grad_enabled(True) + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_4 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_4 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _set_fwd_grad_enabled_5 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_5 = None + _set_fwd_grad_enabled_6 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_6 = None + _set_fwd_grad_enabled_7 = 
torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_7 = None return (primals_out_unflatten, tangents_out_unflatten) """, ) @@ -4997,22 +4997,22 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3, 3]"): l_x_ = L_x_ - _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_x_,)) + _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_x_,)); _jvp_treespec_compare = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True) - _enter_dual_level = torch._C._enter_dual_level() + _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None + _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None - _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None child = torch._make_dual(l_x_, l_x_, level = 0); l_x_ = None - _jvp_treespec_compare_1 = torch._functorch.eager_transforms._jvp_treespec_compare((child,), (child,)) + _jvp_treespec_compare_1 = torch._functorch.eager_transforms._jvp_treespec_compare((child,), (child,)); _jvp_treespec_compare_1 = None - _jvp_increment_nesting_1 = torch._C._functorch._jvp_increment_nesting() - _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True) + _jvp_increment_nesting_1 = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting_1 = None + _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None - _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions() + _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None _make_dual_1 = torch._make_dual(child, child, level = 0); child = None @@ -5026,8 +5026,8 @@ class GraphModule(torch.nn.Module): tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None - _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(True) - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting() + _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_2 = None + _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None _unpack_dual_1 = torch._unpack_dual(primals_out_unflatten, level = 0); primals_out_unflatten = None primal_1 = _unpack_dual_1[0] @@ -5042,9 +5042,9 @@ class GraphModule(torch.nn.Module): _unwrap_for_grad_4: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_1, 1); dual_1 = None _unwrap_for_grad_5: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_2, 1); dual_2 = None - _exit_dual_level = torch._C._exit_dual_level(0) - _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True) - _jvp_decrement_nesting_1 = torch._C._functorch._jvp_decrement_nesting() + _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None + _jvp_decrement_nesting_1 = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting_1 = None return (_unwrap_for_grad_2, _unwrap_for_grad_3, _unwrap_for_grad_4, _unwrap_for_grad_5) """, ) @@ -5119,7 +5119,7 @@ class GraphModule(torch.nn.Module): 
cos_default: "f32[3, 3, 3]" = torch.ops.aten.cos.default(alias_default_1); alias_default_1 = None - alias_default_2: "f32[3, 3, 3]" = torch.ops.aten.alias.default(sin_default) + alias_default_2: "f32[3, 3, 3]" = torch.ops.aten.alias.default(sin_default); alias_default_2 = None return (alias_default, cos_default, sin_default) """, ) @@ -5346,9 +5346,9 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3, 3]"): l_x_ = L_x_ - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None @@ -5358,7 +5358,7 @@ class GraphModule(torch.nn.Module): _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None return (_remove_batch_dim,) """, ) @@ -5384,9 +5384,9 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3, 3]"): l_x_ = L_x_ - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None @@ -5397,7 +5397,7 @@ class GraphModule(torch.nn.Module): _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None return (_remove_batch_dim,) """, ) @@ -5424,9 +5424,9 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_y_ = L_y_ - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None @@ -5437,7 +5437,7 @@ class GraphModule(torch.nn.Module): _remove_batch_dim: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None return (_remove_batch_dim,) """, ) @@ -5465,9 +5465,9 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_y_ = L_y_ - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = 
torch._C._functorch._vmap_increment_nesting(3, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None _add_batch_dim_1 = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None @@ -5479,7 +5479,7 @@ class GraphModule(torch.nn.Module): _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None return (_remove_batch_dim,) """, ) @@ -5509,9 +5509,9 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_y_ = L_y_ - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None _add_batch_dim_1 = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None @@ -5523,7 +5523,7 @@ class GraphModule(torch.nn.Module): _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None return (_remove_batch_dim,) """, ) @@ -5549,16 +5549,16 @@ class GraphModule(torch.nn.Module): l_x_ = L_x_ l_y_ = L_y_ - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None child = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None child_1 = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None - lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None - _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error') + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None _add_batch_dim_2 = torch._C._functorch._add_batch_dim(child, 1, 2); child = None _add_batch_dim_3 = torch._C._functorch._add_batch_dim(child_1, 1, 2); child_1 = None @@ -5567,11 +5567,11 @@ class GraphModule(torch.nn.Module): batched_outputs_1 = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None _remove_batch_dim_1: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 3, 0); batched_outputs_1 = None - _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); 
_vmap_decrement_nesting_1 = None return (_remove_batch_dim_1,) """, ) @@ -5598,15 +5598,15 @@ class GraphModule(torch.nn.Module): l_y_ = L_y_ l_x_ = L_x_ - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error'); _vmap_increment_nesting = None child = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None - lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None - _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error') + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None _add_batch_dim_1 = torch._C._functorch._add_batch_dim(child, 0, 2); child = None @@ -5614,11 +5614,11 @@ class GraphModule(torch.nn.Module): batched_outputs_1 = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None _remove_batch_dim_1: "f32[5, 3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 5, 0); batched_outputs_1 = None - _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None return (_remove_batch_dim_1,) """, ) @@ -5643,9 +5643,9 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[2, 4, 3]"): l_x_ = L_x_ - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None @@ -5655,7 +5655,7 @@ class GraphModule(torch.nn.Module): _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 0); child = None _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None return (_remove_batch_dim, _remove_batch_dim_1) """, ) @@ -5680,9 +5680,9 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[2, 4, 3]"): l_x_ = L_x_ - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None @@ -5692,7 +5692,7 @@ class GraphModule(torch.nn.Module): _remove_batch_dim: "f32[3, 2]" = 
torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None return (_remove_batch_dim, _remove_batch_dim_1) """, ) @@ -5718,9 +5718,9 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[2, 4, 3]"): l_x_ = L_x_ - lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions() + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None - _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error') + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None @@ -5730,7 +5730,7 @@ class GraphModule(torch.nn.Module): _remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None - _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting() + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None return (_remove_batch_dim, _remove_batch_dim_1) """, ) diff --git a/test/dynamo/test_input_attr_tracking.py b/test/dynamo/test_input_attr_tracking.py index f402d54eb966..9aa3a14d517b 100644 --- a/test/dynamo/test_input_attr_tracking.py +++ b/test/dynamo/test_input_attr_tracking.py @@ -320,13 +320,13 @@ class GraphModule(torch.nn.Module): detach: "f32[2, 2]" = l_y_.detach() - _set_grad_enabled = torch._C._set_grad_enabled(False) + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None set_: "f32[2, 2]" = torch_Tensor_set_(l_x_, detach); detach = None - _set_grad_enabled_1 = torch._C._set_grad_enabled(True) + _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None - _lower_version_count_by_1 = torch__dynamo_variables_builtin__lower_version_count_by_1(set_); set_ = None + _lower_version_count_by_1 = torch__dynamo_variables_builtin__lower_version_count_by_1(set_); set_ = _lower_version_count_by_1 = None mul: "f32[2, 2]" = l_x_ * l_y_; l_x_ = l_y_ = None return (mul,) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 794ccb8b1328..6e4b6151745b 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -765,7 +765,7 @@ class MiscTests(torch._inductor.test_case.TestCase): """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None return ()""", ) @@ -916,7 +916,7 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3 """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); 
arg2_1 = arg3_1 = arg1_1 = arg0_1 = None + foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None return ()""", ) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index d58b527ea2cd..ca4d1f02e4ef 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -4665,7 +4665,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): getitem_2 = l_x_[0] sum_1 = getitem_2.sum(); getitem_2 = None gt_1 = sum_1 > 0; sum_1 = None - _assert_async = torch._assert_async(gt_1, 'assertion error'); gt_1 = None + _assert_async = torch._assert_async(gt_1, 'assertion error'); gt_1 = _assert_async = None cos = l_x_.cos(); l_x_ = None return (cos,)""", ) diff --git a/test/export/test_export.py b/test/export/test_export.py index 731abd54ec8b..50c232cd676a 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3145,9 +3145,9 @@ def forward(self, x): getitem = _native_batch_norm_legit_functional[0] getitem_3 = _native_batch_norm_legit_functional[3] getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None - copy__default = torch.ops.aten.copy_.default(bn_running_mean, getitem_3); bn_running_mean = getitem_3 = None - copy__default_1 = torch.ops.aten.copy_.default(bn_running_var, getitem_4); bn_running_var = getitem_4 = None - copy__default_2 = torch.ops.aten.copy_.default(bn_num_batches_tracked, add); bn_num_batches_tracked = add = None + copy__default = torch.ops.aten.copy_.default(bn_running_mean, getitem_3); bn_running_mean = getitem_3 = copy__default = None + copy__default_1 = torch.ops.aten.copy_.default(bn_running_var, getitem_4); bn_running_var = getitem_4 = copy__default_1 = None + copy__default_2 = torch.ops.aten.copy_.default(bn_num_batches_tracked, add); bn_num_batches_tracked = add = copy__default_2 = None return pytree.tree_unflatten((getitem,), self._out_spec)""", ) @@ -5720,15 +5720,15 @@ def forward(self, x): """\ def forward(self, x): item = torch.ops.aten.item.default(x); x = None - sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item) + sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item); sym_constrain_range_for_size_default = None ge = item >= 3 - _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 3 on node 'ge'"); ge = None + _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 3 on node 'ge'"); ge = _assert_scalar_default = None le = item <= 5 - _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= 5 on node 'le'"); le = None + _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= 5 on node 'le'"); le = _assert_scalar_default_1 = None gt = item > 2 - _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 2 < u1 on node 'gt'"); gt = None + _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 2 < u1 on node 'gt'"); gt = _assert_scalar_default_2 = None lt = item < 6 - _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression u1 < 6 on node 'lt'"); lt = None + _assert_scalar_default_3 = 
torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression u1 < 6 on node 'lt'"); lt = _assert_scalar_default_3 = None foo_unbacked = torch.ops.testlib.foo_unbacked.default(item); item = None return foo_unbacked""", ) @@ -5740,11 +5740,11 @@ def forward(self, x, y): sin = torch.ops.aten.sin.default(y) sum_1 = torch.ops.aten.sum.dim_IntList(sin, []); sin = None _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x); x = None - sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense) + sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense); sym_constrain_range_for_size_default = None ge_1 = _local_scalar_dense >= 3 - _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 3 on node 'ge_1'"); ge_1 = None + _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 3 on node 'ge_1'"); ge_1 = _assert_scalar_default = None le_1 = _local_scalar_dense <= 5; _local_scalar_dense = None - _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u3 <= 5 on node 'le_1'"); le_1 = None + _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u3 <= 5 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None full = torch.ops.aten.full.default([4, 4], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) add = torch.ops.aten.add.Tensor(y, sum_1); y = sum_1 = None sum_2 = torch.ops.aten.sum.dim_IntList(full, []); full = None diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 9032b7fc2f4c..8e0f3cf31d9d 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -812,7 +812,7 @@ def forward(self, x1, x2): new_gm.submod_1.code.strip("\n"), """\ def forward(self, x1, x2): - _set_grad_enabled = torch._C._set_grad_enabled(True) + _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None add = torch.ops.aten.add.Tensor(x1, 1); x1 = None add_1 = torch.ops.aten.add.Tensor(x2, 1); x2 = None return (add, add_1) @@ -822,7 +822,7 @@ def forward(self, x1, x2): new_gm.submod_2.code.strip("\n"), """\ def forward(self, add, add_1): - _set_grad_enabled_1 = torch._C._set_grad_enabled(False) + _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None sin = torch.ops.aten.sin.default(add); add = None cos = torch.ops.aten.cos.default(add_1); add_1 = None return (sin, cos) @@ -832,7 +832,7 @@ def forward(self, add, add_1): new_gm.submod_3.code.strip("\n"), """\ def forward(self, sin, cos): - _set_grad_enabled_2 = torch._C._set_grad_enabled(True) + _set_grad_enabled_2 = torch._C._set_grad_enabled(True); _set_grad_enabled_2 = None add_2 = torch.ops.aten.add.Tensor(sin, 1); sin = None add_3 = torch.ops.aten.add.Tensor(cos, 1); cos = None return (add_2, add_3) diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index fd180326fd3a..251ab9d29b59 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -595,14 +595,14 @@ def forward(self, token, obj_attr, x): """\ def forward(self, arg0_1, arg1_1): cos = torch.ops.aten.cos.default(arg1_1) - call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'push', cos); cos = None + call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 
'push', cos); cos = call_torchbind = None sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None - call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin); sin = None + call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin); sin = call_torchbind_1 = None call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') - call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size') + call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size'); call_torchbind_3 = None add = torch.ops.aten.add.Tensor(call_torchbind_2, 1); call_torchbind_2 = None call_torchbind_4 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') - call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size') + call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size'); call_torchbind_5 = None sub = torch.ops.aten.sub.Tensor(call_torchbind_4, 0); call_torchbind_4 = None return (sub, add, arg0_1) """, @@ -656,11 +656,11 @@ def forward(self, arg0_1, arg1_1): """\ def forward(self, arg0_1, arg1_1): call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') - call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size') + call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size'); call_torchbind_1 = None add = torch.ops.aten.add.Tensor(call_torchbind, 1); call_torchbind = None add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = None call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') - call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size') + call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size'); call_torchbind_3 = None sub = torch.ops.aten.sub.Tensor(call_torchbind_2, 0); call_torchbind_2 = None add_2 = torch.ops.aten.add.Tensor(sub, arg1_1); sub = arg1_1 = None return (add_2, add_1, arg0_1) @@ -917,14 +917,14 @@ def forward(self, token, safe_obj): """\ def forward(self, arg0_1, arg1_1): cos = torch.ops.aten.cos.default(arg1_1) - queue_push = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, cos); cos = None + queue_push = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, cos); cos = queue_push = None sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None - queue_push_1 = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, sin); sin = None + queue_push_1 = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, sin); sin = queue_push_1 = None queue_pop = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1) - queue_size = torch.ops._TorchScriptTesting.queue_size.default(arg0_1) + queue_size = torch.ops._TorchScriptTesting.queue_size.default(arg0_1); queue_size = None sub = torch.ops.aten.sub.Tensor(queue_pop, 1); queue_pop = None queue_pop_1 = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1) - queue_size_1 = torch.ops._TorchScriptTesting.queue_size.default(arg0_1) + queue_size_1 = torch.ops._TorchScriptTesting.queue_size.default(arg0_1); queue_size_1 = None add = torch.ops.aten.add.Tensor(queue_pop_1, 0); queue_pop_1 = None return (sub, add, arg0_1)""", ) @@ -1005,7 +1005,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): """\ def forward(self, tq, x): tq, x, = fx_pytree.tree_flatten_spec(([tq, x], {}), self._in_spec) - queue_push_default = torch.ops._TorchScriptTesting.queue_push.default(tq, x); x = None + queue_push_default = torch.ops._TorchScriptTesting.queue_push.default(tq, x); x = queue_push_default = None return pytree.tree_unflatten((tq,), self._out_spec)""", ) self.assertExpectedInline( diff --git 
a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 599397c8aa6f..03a379ecd024 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -641,8 +641,8 @@ def forward(self, primals_1): def forward(self, primals_1, primals_2): mul = torch.ops.aten.mul.Tensor(primals_2, 2) add = torch.ops.aten.add.Tensor(mul, mul) - set_ = torch.ops.aten.set_.source_Tensor(primals_1, mul); primals_1 = None - copy_ = torch.ops.aten.copy_.default(primals_2, mul); primals_2 = mul = None + set_ = torch.ops.aten.set_.source_Tensor(primals_1, mul); primals_1 = set_ = None + copy_ = torch.ops.aten.copy_.default(primals_2, mul); primals_2 = mul = copy_ = None return [add]""", ) @@ -766,11 +766,11 @@ def forward(self, primals_1): view = torch.ops.aten.view.default(arange, [3, 3]); arange = None arange_1 = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) view_1 = torch.ops.aten.view.default(arange_1, [3, 3]); arange_1 = None - set_ = torch.ops.fsdp.set_.default(primals_1, view); view = None + set_ = torch.ops.fsdp.set_.default(primals_1, view); view = set_ = None mul = torch.ops.aten.mul.Tensor(primals_1, primals_1) - set__1 = torch.ops.fsdp.set_.default(primals_1, view_1) + set__1 = torch.ops.fsdp.set_.default(primals_1, view_1); set__1 = None mul_1 = torch.ops.aten.mul.Tensor(primals_1, primals_1) - set__2 = torch.ops.fsdp.set_.default(primals_1, view_1); view_1 = None + set__2 = torch.ops.fsdp.set_.default(primals_1, view_1); view_1 = set__2 = None mul_2 = torch.ops.aten.mul.Tensor(primals_1, primals_1) add = torch.ops.aten.add.Tensor(mul, mul_1); mul = mul_1 = None add_1 = torch.ops.aten.add.Tensor(add, mul_2); add = mul_2 = None @@ -1166,11 +1166,11 @@ def forward(self, arg0_1, arg1_1): fw_graph_cell[0].code.strip(), """\ def forward(self, primals_1): - resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 32) + resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 32); resize_storage_bytes_ = None ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False) copy = torch.ops.aten.copy.default(primals_1, ones); ones = None add = torch.ops.aten.add.Tensor(copy, 1) - copy_ = torch.ops.aten.copy_.default(primals_1, copy); primals_1 = copy = None + copy_ = torch.ops.aten.copy_.default(primals_1, copy); primals_1 = copy = copy_ = None return [add]""", ) @@ -1203,7 +1203,7 @@ def forward(self, primals_1): """\ def forward(self, primals_1): sin = torch.ops.aten.sin.default(primals_1) - resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 0) + resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 0); resize_storage_bytes_ = None return [sin, primals_1]""", ) @@ -1303,8 +1303,8 @@ def forward(self, primals_1): def forward(self, primals_1, primals_2): cat = torch.ops.aten.cat.default([primals_2, primals_2]); primals_2 = None sin = torch.ops.aten.sin.default(cat) - resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(cat, 0) - set_ = torch.ops.aten.set_.source_Tensor(primals_1, cat); primals_1 = None + resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(cat, 0); resize_storage_bytes_ = None + set_ = torch.ops.aten.set_.source_Tensor(primals_1, cat); primals_1 = set_ = None return [sin, cat]""", ) @@ -1400,7 +1400,7 @@ def forward(self, primals_1): mul = torch.ops.aten.mul.Tensor(view, 2); view = None view_1 = 
torch.ops.aten.view.default(mul, [4]); mul = None add = torch.ops.aten.add.Tensor(view_1, 1) - copy_ = torch.ops.aten.copy_.default(primals_1, view_1); primals_1 = view_1 = None + copy_ = torch.ops.aten.copy_.default(primals_1, view_1); primals_1 = view_1 = copy_ = None return [add]""", ) @@ -1422,7 +1422,7 @@ def forward(self, primals_1): def forward(self, primals_1): mul = torch.ops.aten.mul.Tensor(primals_1, 2) add = torch.ops.aten.add.Tensor(mul, 3) - copy_ = torch.ops.aten.copy_.default(primals_1, mul); primals_1 = mul = None + copy_ = torch.ops.aten.copy_.default(primals_1, mul); primals_1 = mul = copy_ = None return [add]""", ) @@ -1444,7 +1444,7 @@ def forward(self, primals_1): def forward(self, arg0_1): mul = torch.ops.aten.mul.Tensor(arg0_1, 2) add = torch.ops.aten.add.Tensor(mul, 3) - copy_ = torch.ops.aten.copy_.default(arg0_1, mul); arg0_1 = mul = None + copy_ = torch.ops.aten.copy_.default(arg0_1, mul); arg0_1 = mul = copy_ = None return (add,)""", ) @@ -3609,7 +3609,7 @@ def forward(self, primals_1, primals_2, primals_3, primals_4): sum_1 = torch.ops.aten.sum.default(mul_1); mul_1 = None sum_2 = torch.ops.aten.sum.default(add) add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None - copy_ = torch.ops.aten.copy_.default(primals_3, add); primals_3 = add = None + copy_ = torch.ops.aten.copy_.default(primals_3, add); primals_3 = add = copy_ = None return [add_1, primals_1, primals_2, primals_4, mul]""", ) @@ -3664,7 +3664,7 @@ def forward(self, primals_1, primals_2, primals_3): sum_1 = torch.ops.aten.sum.default(mm); mm = None sum_2 = torch.ops.aten.sum.default(add) add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None - copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = None + copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = copy_ = None return [add_1, primals_1, primals_3]""", ) self.assertEqual(out_ref, out_test) @@ -3720,9 +3720,9 @@ def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals getitem_2 = _native_batch_norm_legit_functional[2] getitem_3 = _native_batch_norm_legit_functional[3] getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None - copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3); primals_3 = None - copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4); primals_4 = None - copy__2 = torch.ops.aten.copy_.default(primals_5, add); primals_5 = add = None + copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3); primals_3 = copy_ = None + copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4); primals_4 = copy__1 = None + copy__2 = torch.ops.aten.copy_.default(primals_5, add); primals_5 = add = copy__2 = None return [getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4]""", # noqa: B950 ) @@ -4076,9 +4076,9 @@ def forward(self, arg0_1, arg1_1): """\ def forward(self, arg0_1, arg1_1): add = torch.ops.aten.add.Tensor(arg1_1, 2) - _set_grad_enabled = torch._C._set_grad_enabled(False) + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None add_1 = torch.ops.aten.add.Tensor(add, 2); add = None - _set_grad_enabled_1 = torch._C._set_grad_enabled(False) + _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None mul = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None add_2 = torch.ops.aten.add.Tensor(mul, add_1); mul = add_1 = None return (add_2,)""", @@ -4100,9 +4100,9 @@ def forward(self, arg0_1, arg1_1): str(gm.code).strip(), """\ def 
forward(self, arg0_1, arg1_1): - _set_grad_enabled = torch._C._set_grad_enabled(True) + _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None matmul = torch.ops.aten.matmul.default(arg1_1, arg1_1) - _set_grad_enabled_1 = torch._C._set_grad_enabled(False) + _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None add = torch.ops.aten.add.Tensor(matmul, 2); matmul = None sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None sum_2 = torch.ops.aten.sum.default(add); add = None @@ -4171,9 +4171,9 @@ def forward(self, arg0_1, arg1_1, arg2_1): str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1): - _set_grad_enabled = torch._C._set_grad_enabled(True) + _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None mm = torch.ops.aten.mm.default(arg1_1, arg1_1) - _set_grad_enabled_1 = torch._C._set_grad_enabled(False) + _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None add = torch.ops.aten.add.Tensor(mm, 2); mm = None sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None sum_2 = torch.ops.aten.sum.default(add); add = None @@ -4257,14 +4257,14 @@ def forward(self, arg0_1, arg1_1): str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1): - _set_grad_enabled = torch._C._set_grad_enabled(True) + _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None add = torch.ops.aten.add.Tensor(arg1_1, 5) add_1 = torch.ops.aten.add.Tensor(add, 5); add = None add_2 = torch.ops.aten.add.Tensor(add_1, 7); add_1 = None cos = torch.ops.aten.cos.default(arg1_1); arg1_1 = None sin = torch.ops.aten.sin.default(add_2); add_2 = None add_3 = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None - _set_grad_enabled_1 = torch._C._set_grad_enabled(False) + _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None return (add_3,)""", ) @@ -4410,13 +4410,13 @@ def forward(self, arg0_1, arg1_1): """\ def forward(self, arg0_1, arg1_1): cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None - select = torch.ops.aten.select.int(cos, 0, 0) + select = torch.ops.aten.select.int(cos, 0, 0); select = None body_graph_0 = self.body_graph_0 map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None getitem = map_impl[0]; map_impl = None sum_1 = torch.ops.aten.sum.default(getitem); getitem = None add = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None - select_1 = torch.ops.aten.select.int(cos, 0, 0) + select_1 = torch.ops.aten.select.int(cos, 0, 0); select_1 = None body_graph_1 = self.body_graph_1 map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None getitem_1 = map_impl_1[0]; map_impl_1 = None @@ -4633,7 +4633,7 @@ class (torch.nn.Module): getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None - detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu) + detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); detach = None detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu) detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_2); detach_2 = None @@ -4657,7 +4657,7 @@ class (torch.nn.Module): getitem_6: "f32[3]" = native_batch_norm_backward[1] getitem_7: "f32[3]" 
= native_batch_norm_backward[2]; native_batch_norm_backward = None convolution_backward = torch.ops.aten.convolution_backward.default(getitem_5, arg7_1, arg0_1, [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]); getitem_5 = arg7_1 = arg0_1 = None - getitem_8 = convolution_backward[0] + getitem_8 = convolution_backward[0]; getitem_8 = None getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1] getitem_10: "f32[3]" = convolution_backward[2]; convolution_backward = None return (getitem_3, getitem_4, add, sum_1, detach_10, getitem_9, getitem_10, getitem_6, getitem_7) @@ -4954,7 +4954,7 @@ def forward(self, arg0_1): """\ def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(arg0_1, 4) - add_1 = torch.ops.aten.add.Tensor(add, 5); add = None + add_1 = torch.ops.aten.add.Tensor(add, 5); add = add_1 = None cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None return (cos,)""", ) @@ -4964,7 +4964,7 @@ def forward(self, arg0_1): """\ def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(arg0_1, 5) - add_1 = torch.ops.aten.add.Tensor(add, 6); add = None + add_1 = torch.ops.aten.add.Tensor(add, 6); add = add_1 = None sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None return (sin,)""", ) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index b2e35a68392a..1f5619d578ca 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -446,7 +446,7 @@ def forward(self, pred_1, x_1, y_1, z_1): false_graph_1 = self.false_graph_1 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = z_1 = y_1 = None getitem_1 = cond_1[0] - getitem_2 = cond_1[1]; cond_1 = None + getitem_2 = cond_1[1]; cond_1 = getitem_2 = None return (getitem_1,)""", # noqa: B950 ) @@ -505,10 +505,10 @@ def forward(self, pred_1, x_1): _param_constant1_1 = self._param_constant1 _tensor_constant0_1 = self._tensor_constant0 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, _tensor_constant0_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = _tensor_constant0_1 = None - getitem_1 = cond_1[0] + getitem_1 = cond_1[0]; getitem_1 = None getitem_2 = cond_1[1] - getitem_3 = cond_1[2] - getitem_4 = cond_1[3]; cond_1 = None + getitem_3 = cond_1[2]; getitem_3 = None + getitem_4 = cond_1[3]; cond_1 = getitem_4 = None return (getitem_2,)""", # noqa: B950 ) @@ -621,7 +621,7 @@ def forward(self, pred_1, a_1, b_1, c_1): cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, c_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = c_1 = None getitem_1 = cond_1[0] getitem_2 = cond_1[1] - getitem_3 = cond_1[2]; cond_1 = None + getitem_3 = cond_1[2]; cond_1 = getitem_3 = None return (getitem_1, getitem_2)""", # noqa: B950 ) # Forward @@ -637,7 +637,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): gm.true_graph_1.code.strip(), """\ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1): - add = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = None + add = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = add = None clone = torch.ops.aten.clone.default(arg0_1) clone_1 = torch.ops.aten.clone.default(arg0_1); arg0_1 = None return [clone, clone_1, None]""", @@ -695,7 +695,7 @@ def forward(self, pred_1): cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, 
false_graph_1, (ones_like, _tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = None getitem_1 = cond_1[0] getitem_2 = cond_1[1] - getitem_3 = cond_1[2]; cond_1 = None + getitem_3 = cond_1[2]; cond_1 = getitem_3 = None return (getitem_1, getitem_2)""", # noqa: B950 ) @@ -823,12 +823,12 @@ def forward(self, pred_1, x_1): _param_constant5_1 = self._param_constant5 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = None getitem_1 = cond_1[0] - getitem_2 = cond_1[1] - getitem_3 = cond_1[2] - getitem_4 = cond_1[3] - getitem_5 = cond_1[4] - getitem_6 = cond_1[5] - getitem_7 = cond_1[6]; cond_1 = None + getitem_2 = cond_1[1]; getitem_2 = None + getitem_3 = cond_1[2]; getitem_3 = None + getitem_4 = cond_1[3]; getitem_4 = None + getitem_5 = cond_1[4]; getitem_5 = None + getitem_6 = cond_1[5]; getitem_6 = None + getitem_7 = cond_1[6]; cond_1 = getitem_7 = None return (getitem_1,)""", # noqa: B950 ) @@ -1893,7 +1893,7 @@ def forward(self, x_1): view_2 = torch.ops.aten.view.default(view_1, [4, 5]) sin = torch.ops.aten.sin.default(view_2); view_2 = None sum_1 = torch.ops.aten.sum.default(sin); sin = None - copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None + copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = copy_ = None return sum_1""", ) @@ -1934,7 +1934,7 @@ def forward(self, x_1): view_2 = torch.ops.aten.view.default(view_1, [5, 5]) cos = torch.ops.aten.cos.default(view_2); view_2 = None sum_1 = torch.ops.aten.sum.default(cos); cos = None - copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None + copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = copy_ = None return sum_1""", ) @@ -3495,10 +3495,10 @@ def forward(self, l_inp_, l_tmp_): a = l_inp__1.clone(); l_inp__1 = None a_view = a.view(-1) tmp = l_tmp__1.clone(); l_tmp__1 = None - _set_grad_enabled = torch._C._set_grad_enabled(False) - set_ = a.set_(tmp) - mul_ = a_view.mul_(2); a_view = None - _set_grad_enabled_1 = torch._C._set_grad_enabled(True) + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None + set_ = a.set_(tmp); set_ = None + mul_ = a_view.mul_(2); a_view = mul_ = None + _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None add = a + tmp; a = tmp = None return (add,) """, diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 4a68b394be45..21ff2edf2392 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -4756,8 +4756,8 @@ def forward(self, x_1) -> torch.Tensor: view_copy = torch.ops.aten.view_copy.default(x_1, [4, 2]) add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None - view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2]) - copy_ = torch.ops.aten.copy_.default(x_1, view_copy_1); x_1 = None + view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2]); view_copy_2 = None + copy_ = torch.ops.aten.copy_.default(x_1, view_copy_1); x_1 = 
copy_ = None return view_copy_1 """, ) @@ -4799,13 +4799,13 @@ def forward(self, x_1) -> torch.Tensor: def forward(self, inpt_1) -> torch.Tensor: - empty = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = 'cpu', pin_memory = False) + empty = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = 'cpu', pin_memory = False); empty = None add = torch.ops.aten.add.Tensor(inpt_1, inpt_1); inpt_1 = None - view_copy = torch.ops.aten.view_copy.default(add, [4]) + view_copy = torch.ops.aten.view_copy.default(add, [4]); view_copy = None view_copy_1 = torch.ops.aten.view_copy.default(add, [4]); add = None add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4]); add_1 = None - view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4]) + view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4]); view_copy_3 = None return view_copy_2 """, ) @@ -4829,15 +4829,15 @@ def forward(self, inpt_1) -> torch.Tensor: def forward(self, inpt_1) -> torch.Tensor: - empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = 'cpu', pin_memory = False) + empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = 'cpu', pin_memory = False); empty = None empty_1 = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = 'cpu', pin_memory = False) - view_copy = torch.ops.aten.view_copy.default(empty_1, [4]); empty_1 = None + view_copy = torch.ops.aten.view_copy.default(empty_1, [4]); empty_1 = view_copy = None view_copy_1 = torch.ops.aten.view_copy.default(inpt_1, [2, 4]); inpt_1 = None aminmax = torch.ops.aten.aminmax.default(view_copy_1, dim = 0); view_copy_1 = None getitem = aminmax[0] getitem_1 = aminmax[1]; aminmax = None view_copy_2 = torch.ops.aten.view_copy.default(getitem_1, [2, 2]); getitem_1 = None - view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4]) + view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4]); view_copy_3 = None return (view_copy_2, getitem) """, ) @@ -4862,8 +4862,8 @@ def forward(self, x_1) -> torch.Tensor: view = torch.ops.aten.view.default(x_1, [4, 2]) add = torch.ops.aten.add.Tensor(view, ones); view = ones = None view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None - view_2 = torch.ops.aten.view.default(view_1, [4, 2]) - copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = None + view_2 = torch.ops.aten.view.default(view_1, [4, 2]); view_2 = None + copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = copy_ = None return view_1 """, ) @@ -4952,7 +4952,7 @@ def forward(self, x_1): resize = torch.ops.aten.resize.default(x_1, [10]) fill = torch.ops.aten.fill.Scalar(resize, 2); resize = None resize_ = torch.ops.aten.resize_.default(x_1, [10]); x_1 = None - copy_ = torch.ops.aten.copy_.default(resize_, fill); resize_ = fill = None + copy_ = torch.ops.aten.copy_.default(resize_, fill); resize_ = fill = copy_ = None return None """, ) diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index a3c5f1bb1466..9ca42b15e377 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -76,7 +76,7 @@ def forward(self, arg1_1): add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None getitem_2 = with_effects_1[0]; with_effects_1 = None - _sink_tokens_default = 
torch.ops.prims._sink_tokens.default((getitem_2,)); getitem_2 = None + _sink_tokens_default = torch.ops.prims._sink_tokens.default((getitem_2,)); getitem_2 = _sink_tokens_default = None return (add,)""", # noqa: B950 ) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 1cae48de2799..22944e4acad2 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1535,16 +1535,16 @@ class GraphModule(torch.nn.Module): l_kwargs_block_mask_full_q_num_blocks = L_kwargs_block_mask_full_q_num_blocks l_kwargs_block_mask_full_q_indices = L_kwargs_block_mask_full_q_indices - child_1: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32) - child_2: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32) - child_3: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32) - child_4: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32) - child: "f64[]" = l_args_0_.new_empty([], requires_grad = True) + child_1: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32); child_1 = None + child_2: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32); child_2 = None + child_3: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32); child_3 = None + child_4: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32); child_4 = None + child: "f64[]" = l_args_0_.new_empty([], requires_grad = True); child = None score_mod_0 = self.score_mod_0 - child_5: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32) - child_6: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32) - child_7: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32) - child_8: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32) + child_5: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32); child_5 = None + child_6: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32); child_6 = None + child_7: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32); child_7 = None + child_8: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32); child_8 = None mask_fn_0 = self.mask_fn_0 flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, score_mod_0, (l_kwargs_block_mask_kv_num_blocks, l_kwargs_block_mask_kv_indices, l_kwargs_block_mask_full_kv_num_blocks, l_kwargs_block_mask_full_kv_indices, l_kwargs_block_mask_q_num_blocks, l_kwargs_block_mask_q_indices, l_kwargs_block_mask_full_q_num_blocks, l_kwargs_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, (), ()); l_args_0_ = l_args_1_ = l_args_2_ = score_mod_0 = l_kwargs_block_mask_kv_num_blocks = l_kwargs_block_mask_kv_indices = l_kwargs_block_mask_full_kv_num_blocks = l_kwargs_block_mask_full_kv_indices = l_kwargs_block_mask_q_num_blocks = l_kwargs_block_mask_q_indices = l_kwargs_block_mask_full_q_num_blocks = l_kwargs_block_mask_full_q_indices = mask_fn_0 = None out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None @@ -1599,7 +1599,7 @@ class GraphModule(torch.nn.Module): class (torch.nn.Module): def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"): - mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1) + mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); mul = None mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1) mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None diff --git a/test/test_functionalization.py b/test/test_functionalization.py index 978b58b492c0..3626145d77ce 100644 --- 
a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -249,22 +249,22 @@ def forward(self, arg0_1): relu = torch.ops.aten.relu.default(view_copy_1); view_copy_1 = None view_copy_2 = torch.ops.aten.view_copy.default(relu, [1, 1024, 128, 128]); relu = None view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [16, 64, 128, 128]); view_copy_2 = None - view_copy_4 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = None + view_copy_4 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = view_copy_4 = None sum_1 = torch.ops.aten.sum.default(view_copy_3) ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None view_copy_5 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_5, [1, 1024, 128, 128], [16777216, 16384, 128, 1]) copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_5); new_empty_strided = view_copy_5 = None - view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]) + view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); view_copy_6 = None view_copy_7 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]) clone_1 = torch.ops.aten.clone.default(view_copy_7, memory_format = torch.contiguous_format) threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, view_copy_3, 0); clone_1 = view_copy_3 = None copy_1 = torch.ops.aten.copy.default(view_copy_7, threshold_backward); view_copy_7 = threshold_backward = None view_copy_8 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]); copy_1 = None - view_copy_9 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]) + view_copy_9 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]); view_copy_9 = None view_copy_10 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); copy = None - detach_copy = torch.ops.aten.detach_copy.default(view_copy_10); view_copy_10 = None + detach_copy = torch.ops.aten.detach_copy.default(view_copy_10); view_copy_10 = detach_copy = None view_copy_11 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]); view_copy_8 = None detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_11); view_copy_11 = None return detach_copy_1 @@ -294,8 +294,8 @@ def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2]) - mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1) - copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None + mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1); mul = None + copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = copy_ = None return view_copy_2 """, ) @@ -315,8 +315,8 @@ def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(view, ones); view = ones = None view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None view_2 = torch.ops.aten.view.default(view_1, [4, 2]) - mul = torch.ops.aten.mul.Tensor(view_1, view_1) - copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None + mul = torch.ops.aten.mul.Tensor(view_1, view_1); mul = None + copy_ = 
torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = copy_ = None return view_2 """, ) @@ -342,7 +342,7 @@ def forward(self, arg0_1): def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None - empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False) + empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False); empty = None add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None mul = torch.ops.aten.mul.Tensor(add, add); add = None return mul @@ -361,7 +361,7 @@ def forward(self, arg0_1): def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) view = torch.ops.aten.view.default(arg0_1, [4, 2]); arg0_1 = None - empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False) + empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False); empty = None add = torch.ops.aten.add.Tensor(view, ones); view = ones = None mul = torch.ops.aten.mul.Tensor(add, add); add = None return mul @@ -386,11 +386,11 @@ def forward(self, arg0_1): def forward(self, arg0_1): - empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False) - empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False) + empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty = None + empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty_1 = None aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None getitem = aminmax[0] - getitem_1 = aminmax[1]; aminmax = None + getitem_1 = aminmax[1]; aminmax = getitem_1 = None return getitem """, ) @@ -408,11 +408,11 @@ def forward(self, arg0_1): def forward(self, arg0_1): - empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False) - empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False) + empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty = None + empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False); empty_1 = None aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None getitem = aminmax[0] - getitem_1 = aminmax[1]; aminmax = None + getitem_1 = aminmax[1]; aminmax = getitem_1 = None return getitem """, ) @@ -440,7 +440,7 @@ def forward(self, arg0_1): view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]); lift_fresh_copy = None add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None view_copy_1 = torch.ops.aten.view_copy.default(add, [3]); add = None - view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [-1]) + view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [-1]); view_copy_2 = None return view_copy_1 """, ) @@ -456,9 +456,9 @@ def forward(self, arg0_1): _tensor_constant0 = self._tensor_constant0 lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None view = torch.ops.aten.view.default(lift_fresh_copy, [-1]); lift_fresh_copy = None - add = torch.ops.aten.add_.Tensor(view, 1) + add = torch.ops.aten.add_.Tensor(view, 1); add = None view_1 = 
torch.ops.aten.view.default(view, [3]); view = None - view_2 = torch.ops.aten.view.default(view_1, [-1]) + view_2 = torch.ops.aten.view.default(view_1, [-1]); view_2 = None return view_1 """, ) @@ -508,9 +508,9 @@ def forward(self, arg0_1): def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) - view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]) + view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); view_copy = None add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None - copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = copy_ = None view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None return view_copy_1 """, @@ -527,9 +527,9 @@ def forward(self, arg0_1): def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) - view = torch.ops.aten.view.default(arg0_1, [4, 2]) + view = torch.ops.aten.view.default(arg0_1, [4, 2]); view = None add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None - copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = copy_ = None view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None return view_1 """, @@ -554,11 +554,11 @@ def forward(self, arg0_1): _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0) getitem = _fused_moving_avg_obs_fq_helper_functional[0] getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1] - getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2] - getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3] - getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4] + getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2]; getitem_2 = None + getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3]; getitem_3 = None + getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4]; getitem_4 = None getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None - copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5); arg0_1 = getitem_5 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5); arg0_1 = getitem_5 = copy_ = None return (getitem, getitem_1) """, # noqa: B950 ) @@ -581,8 +581,8 @@ def forward(self, arg0_1): as_strided_copy = torch.ops.aten.as_strided_copy.default(arg0_1, [2], [2], 1) add = torch.ops.aten.add.Tensor(as_strided_copy, 1); as_strided_copy = None as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None - as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1) - copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None + as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1); as_strided_copy_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = copy_ = None return as_strided_scatter """, ) @@ -601,8 +601,8 @@ def forward(self, arg0_1): as_strided = torch.ops.aten.as_strided.default(arg0_1, [2], [2], 1) add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None - as_strided_1 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [2], 1) - 
copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None + as_strided_1 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [2], 1); as_strided_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = copy_ = None return as_strided_scatter """, ) @@ -642,7 +642,7 @@ def forward(self, arg0_1): def forward(self, arg0_1): - empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False) + empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False); empty = None cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None return cat """, @@ -658,7 +658,7 @@ def forward(self, arg0_1): def forward(self, arg0_1): - empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False) + empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False); empty = None cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None return cat """, @@ -687,7 +687,7 @@ def forward(self, arg0_1): diagonal_copy = torch.ops.aten.diagonal_copy.default(clone) add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None diagonal_scatter = torch.ops.aten.diagonal_scatter.default(clone, add); clone = add = None - diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_scatter = None + diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_scatter = diagonal_copy_1 = None mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None return mul """, @@ -706,8 +706,8 @@ def forward(self, arg0_1): ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) clone = torch.ops.aten.clone.default(arg0_1) diagonal = torch.ops.aten.diagonal.default(clone) - add = torch.ops.aten.add_.Tensor(diagonal, ones); diagonal = ones = None - diagonal_1 = torch.ops.aten.diagonal.default(clone); clone = None + add = torch.ops.aten.add_.Tensor(diagonal, ones); diagonal = ones = add = None + diagonal_1 = torch.ops.aten.diagonal.default(clone); clone = diagonal_1 = None mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None return mul """, @@ -735,8 +735,8 @@ def forward(self, arg0_1): diagonal_copy = torch.ops.aten.diagonal_copy.default(arg0_1) add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None - diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) - copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None + diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_copy_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = copy_ = None return diagonal_scatter """, ) @@ -756,8 +756,8 @@ def forward(self, arg0_1): diagonal = torch.ops.aten.diagonal.default(arg0_1) add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None - diagonal_1 = torch.ops.aten.diagonal.default(diagonal_scatter) - copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None + diagonal_1 = torch.ops.aten.diagonal.default(diagonal_scatter); diagonal_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = copy_ = None return diagonal_scatter """, ) @@ -802,21 +802,21 @@ def forward(self, arg0_1): def forward(self, arg0_1): ones = 
torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) split_copy = torch.ops.aten.split_copy.Tensor(arg0_1, 2) - getitem = split_copy[0] + getitem = split_copy[0]; getitem = None getitem_1 = split_copy[1]; split_copy = None diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None split_copy_1 = torch.ops.aten.split_copy.Tensor(arg0_1, 2) - getitem_2 = split_copy_1[0] + getitem_2 = split_copy_1[0]; getitem_2 = None getitem_3 = split_copy_1[1]; split_copy_1 = None diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None split_copy_2 = torch.ops.aten.split_copy.Tensor(slice_scatter, 2) - getitem_4 = split_copy_2[0] + getitem_4 = split_copy_2[0]; getitem_4 = None getitem_5 = split_copy_2[1]; split_copy_2 = None diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_5); getitem_5 = None - mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) - copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None + mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None + copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None return diagonal_copy_1 """, ) # noqa: B950 @@ -834,21 +834,21 @@ def forward(self, arg0_1): def forward(self, arg0_1): ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) split = torch.ops.aten.split.Tensor(arg0_1, 2) - getitem = split[0] + getitem = split[0]; getitem = None getitem_1 = split[1]; split = None diagonal = torch.ops.aten.diagonal.default(getitem_1); getitem_1 = None add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None split_1 = torch.ops.aten.split.Tensor(arg0_1, 2) - getitem_2 = split_1[0] + getitem_2 = split_1[0]; getitem_2 = None getitem_3 = split_1[1]; split_1 = None diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None split_2 = torch.ops.aten.split.Tensor(slice_scatter, 2) - getitem_4 = split_2[0] + getitem_4 = split_2[0]; getitem_4 = None getitem_5 = split_2[1]; split_2 = None diagonal_1 = torch.ops.aten.diagonal.default(getitem_5); getitem_5 = None - mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) - copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None + mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None + copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None return diagonal_1 """, ) # noqa: B950 @@ -875,20 +875,20 @@ def forward(self, arg0_1): ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) split_with_sizes_copy = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2]) getitem = split_with_sizes_copy[0] - getitem_1 = split_with_sizes_copy[1]; split_with_sizes_copy = None + getitem_1 = split_with_sizes_copy[1]; split_with_sizes_copy = getitem_1 = None diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem); getitem = None add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None split_with_sizes_copy_1 = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2]) 
getitem_2 = split_with_sizes_copy_1[0] - getitem_3 = split_with_sizes_copy_1[1]; split_with_sizes_copy_1 = None + getitem_3 = split_with_sizes_copy_1[1]; split_with_sizes_copy_1 = getitem_3 = None diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None split_with_sizes_copy_2 = torch.ops.aten.split_with_sizes_copy.default(slice_scatter, [2, 2]) getitem_4 = split_with_sizes_copy_2[0] - getitem_5 = split_with_sizes_copy_2[1]; split_with_sizes_copy_2 = None + getitem_5 = split_with_sizes_copy_2[1]; split_with_sizes_copy_2 = getitem_5 = None diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_4); getitem_4 = None - mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) - copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None + mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None + copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None return diagonal_copy_1 """, ) # noqa: B950 @@ -907,20 +907,20 @@ def forward(self, arg0_1): ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) split_with_sizes = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2]) getitem = split_with_sizes[0] - getitem_1 = split_with_sizes[1]; split_with_sizes = None + getitem_1 = split_with_sizes[1]; split_with_sizes = getitem_1 = None diagonal = torch.ops.aten.diagonal.default(getitem); getitem = None add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2]) getitem_2 = split_with_sizes_1[0] - getitem_3 = split_with_sizes_1[1]; split_with_sizes_1 = None + getitem_3 = split_with_sizes_1[1]; split_with_sizes_1 = getitem_3 = None diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(slice_scatter, [2, 2]) getitem_4 = split_with_sizes_2[0] - getitem_5 = split_with_sizes_2[1]; split_with_sizes_2 = None + getitem_5 = split_with_sizes_2[1]; split_with_sizes_2 = getitem_5 = None diagonal_1 = torch.ops.aten.diagonal.default(getitem_4); getitem_4 = None - mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) - copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None + mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter); mul = None + copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None return diagonal_1 """, ) # noqa: B950 @@ -950,7 +950,7 @@ def forward(self, arg0_1): slice_scatter = torch.ops.aten.slice_scatter.default(transpose_copy_1, add, 0, 0, 2); transpose_copy_1 = add = None transpose_copy_2 = torch.ops.aten.transpose_copy.int(slice_scatter, 1, 0); slice_scatter = None transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0) - slice_copy_1 = torch.ops.aten.slice_copy.Tensor(transpose_copy_3, 0, 0, 2); transpose_copy_3 = None + slice_copy_1 = torch.ops.aten.slice_copy.Tensor(transpose_copy_3, 0, 0, 2); transpose_copy_3 = slice_copy_1 = None transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None return transpose_copy_4 """, @@ -975,7 +975,7 @@ def 
forward(self, arg0_1): slice_scatter = torch.ops.aten.slice_scatter.default(transpose_1, add, 0, 0, 2); transpose_1 = add = None transpose_2 = torch.ops.aten.transpose.int(slice_scatter, 1, 0); slice_scatter = None transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) - slice_2 = torch.ops.aten.slice.Tensor(transpose_3, 0, 0, 2); transpose_3 = None + slice_2 = torch.ops.aten.slice.Tensor(transpose_3, 0, 0, 2); transpose_3 = slice_2 = None transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None return transpose_4 """, @@ -1007,7 +1007,7 @@ def forward(self, arg0_1): select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0) - select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0); transpose_copy_3 = None + select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0); transpose_copy_3 = select_copy_1 = None transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None return transpose_copy_4 """, @@ -1032,7 +1032,7 @@ def forward(self, arg0_1): select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) - select_1 = torch.ops.aten.select.int(transpose_3, 0, 0); transpose_3 = None + select_1 = torch.ops.aten.select.int(transpose_3, 0, 0); transpose_3 = select_1 = None transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None return transpose_4 """, @@ -1060,15 +1060,15 @@ def forward(self, arg0_1): transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0) unbind_copy = torch.ops.aten.unbind_copy.int(transpose_copy); transpose_copy = None getitem = unbind_copy[0] - getitem_1 = unbind_copy[1]; unbind_copy = None + getitem_1 = unbind_copy[1]; unbind_copy = getitem_1 = None add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0) unbind_copy_1 = torch.ops.aten.unbind_copy.int(transpose_copy_3); transpose_copy_3 = None - getitem_2 = unbind_copy_1[0] - getitem_3 = unbind_copy_1[1]; unbind_copy_1 = None + getitem_2 = unbind_copy_1[0]; getitem_2 = None + getitem_3 = unbind_copy_1[1]; unbind_copy_1 = getitem_3 = None transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None return transpose_copy_4 """, @@ -1089,15 +1089,15 @@ def forward(self, arg0_1): transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0) unbind = torch.ops.aten.unbind.int(transpose); transpose = None getitem = unbind[0] - getitem_1 = unbind[1]; unbind = None + getitem_1 = unbind[1]; unbind = getitem_1 = None add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None transpose_2 = 
torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) unbind_1 = torch.ops.aten.unbind.int(transpose_3); transpose_3 = None - getitem_2 = unbind_1[0] - getitem_3 = unbind_1[1]; unbind_1 = None + getitem_2 = unbind_1[0]; getitem_2 = None + getitem_3 = unbind_1[1]; unbind_1 = getitem_3 = None transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None return transpose_4 """, @@ -1128,7 +1128,7 @@ def forward(self, arg0_1): index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1); view_copy = arange = arange_1 = None view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2]); index_put = None view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [8]) - copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = copy_ = None return view_copy_2 """, ) # noqa: B950 @@ -1152,14 +1152,14 @@ def forward(self, arg0_1): def forward(self, arg0_1): - ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) + ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False); ones = None view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]) add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2]) mul = torch.ops.aten.mul.Tensor(view_copy_2, 2); view_copy_2 = None div = torch.ops.aten.div.Tensor(mul, 1); mul = None - copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = copy_ = None return div """, ) @@ -1278,7 +1278,7 @@ def forward(self, arg0_1): squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None getitem = split_copy[0] - getitem_1 = split_copy[1]; split_copy = None + getitem_1 = split_copy[1]; split_copy = getitem_1 = None add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None view_copy_2 = torch.ops.aten.view_copy.default(add, [8]); add = None view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [2, 4]); view_copy_2 = None @@ -1298,9 +1298,9 @@ def forward(self, arg0_1): squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None getitem_2 = split_copy_1[0] - getitem_3 = split_copy_1[1]; split_copy_1 = None - select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = None - view_copy_8 = torch.ops.aten.view_copy.default(getitem_2, [4]) + getitem_3 = split_copy_1[1]; split_copy_1 = getitem_3 = None + select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = select_copy = None + view_copy_8 = torch.ops.aten.view_copy.default(getitem_2, [4]); view_copy_8 = None view_copy_9 = torch.ops.aten.view_copy.default(view_copy_5, [8]) view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None select_copy_1 = torch.ops.aten.select_copy.int(view_copy_10, 0, 0); view_copy_10 = None @@ -1311,9 +1311,9 @@ def forward(self, arg0_1): squeeze_copy_4 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_4); 
unsqueeze_copy_4 = None split_copy_2 = torch.ops.aten.split_copy.Tensor(squeeze_copy_4, 2); squeeze_copy_4 = None getitem_4 = split_copy_2[0] - getitem_5 = split_copy_2[1]; split_copy_2 = None + getitem_5 = split_copy_2[1]; split_copy_2 = getitem_5 = None view_copy_13 = torch.ops.aten.view_copy.default(getitem_4, [4]); getitem_4 = None - add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13); select_copy_1 = view_copy_13 = None + add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13); select_copy_1 = view_copy_13 = add_2 = None return getitem_2 """, ) # noqa: B950 @@ -1337,8 +1337,8 @@ def forward(self, arg0_1): squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None getitem = split[0] - getitem_1 = split[1]; split = None - add_1 = torch.ops.aten.add_.Tensor(getitem, ones); getitem = ones = None + getitem_1 = split[1]; split = getitem_1 = None + add_1 = torch.ops.aten.add_.Tensor(getitem, ones); getitem = ones = add_1 = None view_2 = torch.ops.aten.view.default(add, [8]); add = None view_3 = torch.ops.aten.view.default(view_2, [2, 4]); view_2 = None transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0); view_3 = None @@ -1356,14 +1356,14 @@ def forward(self, arg0_1): squeeze_3 = torch.ops.aten.squeeze.default(unsqueeze_3); unsqueeze_3 = None split_1 = torch.ops.aten.split.Tensor(squeeze_3, 2); squeeze_3 = None getitem_2 = split_1[0] - getitem_3 = split_1[1]; split_1 = None - select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None + getitem_3 = split_1[1]; split_1 = getitem_3 = None + select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = select = None clone = torch.ops.aten.clone.default(getitem_2, memory_format = torch.contiguous_format) _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None view_8 = torch.ops.aten.view.default(view_5, [8]); view_5 = None view_9 = torch.ops.aten.view.default(view_8, [2, 4]); view_8 = None select_1 = torch.ops.aten.select.int(view_9, 0, 0); view_9 = None - add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None + add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = add_2 = None return getitem_2 """, ) @@ -1390,8 +1390,8 @@ def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(view, ones); view = ones = None view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None view_2 = torch.ops.aten.view.default(view_1, [4, 2]) - mul = torch.ops.aten.mul.Tensor(view_1, view_1) - copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None + mul = torch.ops.aten.mul.Tensor(view_1, view_1); mul = None + copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = copy_ = None return view_2 """, ) @@ -1463,9 +1463,9 @@ def forward(self, arg0_1): def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros) - copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None + copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None diagonal_1 = torch.ops.aten.diagonal.default(zeros) - add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None + add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None return diagonal_2 """, @@ -1505,9 +1505,9 @@ def forward(self, arg0_1): def forward(self, 
arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros) - copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None + copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None diagonal_1 = torch.ops.aten.diagonal.default(zeros) - add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None + add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None return diagonal_2 """, @@ -1547,9 +1547,9 @@ def forward(self, arg0_1): def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros) - copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None + copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None diagonal_1 = torch.ops.aten.diagonal.default(zeros) - add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None + add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None return diagonal_2 """, @@ -1589,9 +1589,9 @@ def forward(self, arg0_1): def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros) - copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None + copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = copy = None diagonal_1 = torch.ops.aten.diagonal.default(zeros) - add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None + add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = add = None diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None return diagonal_2 """, @@ -1637,7 +1637,7 @@ def forward(self, arg0_1): diagonal_copy = torch.ops.aten.diagonal_copy.default(add) fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0); diagonal_copy = None diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill); add = fill = None - diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) + diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_copy_1 = None return diagonal_scatter """, ) @@ -1654,8 +1654,8 @@ def forward(self, arg0_1): def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None diagonal = torch.ops.aten.diagonal.default(add) - fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = None - diagonal_1 = torch.ops.aten.diagonal.default(add) + fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = fill = None + diagonal_1 = torch.ops.aten.diagonal.default(add); diagonal_1 = None return add """, ) @@ -1682,18 +1682,18 @@ def forward(self, arg0_1): def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None view_copy = torch.ops.aten.view_copy.default(add, [4, 4]) - resize = torch.ops.aten.resize.default(view_copy, [3, 3]) + resize = torch.ops.aten.resize.default(view_copy, [3, 3]); resize = None as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]); view_copy = None view_copy_1 = torch.ops.aten.view_copy.default(as_strided_copy, [-1]); as_strided_copy = None add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None view_copy_2 = 
torch.ops.aten.view_copy.default(add, [4, 4]); add = None - as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1]) + as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1]); as_strided_copy_1 = None view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]) as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None - view_copy_6 = torch.ops.aten.view_copy.default(as_strided_copy_2, [-1]); as_strided_copy_2 = None + view_copy_6 = torch.ops.aten.view_copy.default(as_strided_copy_2, [-1]); as_strided_copy_2 = view_copy_6 = None view_copy_7 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]); view_copy_4 = None as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]); view_copy_7 = None add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1); as_strided_copy_3 = None @@ -1713,20 +1713,20 @@ def forward(self, arg0_1): def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None view = torch.ops.aten.view.default(add, [4, 4]) - resize = torch.ops.aten.resize.default(view, [3, 3]) + resize = torch.ops.aten.resize.default(view, [3, 3]); resize = None as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]); view = None view_1 = torch.ops.aten.view.default(as_strided, [-1]); as_strided = None - add_1 = torch.ops.aten.add_.Tensor(view_1, 1) + add_1 = torch.ops.aten.add_.Tensor(view_1, 1); add_1 = None view_2 = torch.ops.aten.view.default(add, [4, 4]); add = None - as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1]) - view_3 = torch.ops.aten.view.default(view_1, [3, 3]); view_1 = None + as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1]); as_strided_1 = None + view_3 = torch.ops.aten.view.default(view_1, [3, 3]); view_1 = view_3 = None view_4 = torch.ops.aten.view.default(view_2, [8, 2]); view_2 = None view_5 = torch.ops.aten.view.default(view_4, [4, 4]) as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]); view_5 = None - view_6 = torch.ops.aten.view.default(as_strided_2, [-1]); as_strided_2 = None + view_6 = torch.ops.aten.view.default(as_strided_2, [-1]); as_strided_2 = view_6 = None view_7 = torch.ops.aten.view.default(view_4, [4, 4]); view_4 = None as_strided_3 = torch.ops.aten.as_strided.default(view_7, [3, 3], [3, 1]); view_7 = None - add_2 = torch.ops.aten.add_.Tensor(as_strided_3, 1) + add_2 = torch.ops.aten.add_.Tensor(as_strided_3, 1); add_2 = None return as_strided_3 """, ) @@ -1770,7 +1770,7 @@ def forward(self, arg0_1): view_copy = torch.ops.aten.view_copy.default(resize, [25]); resize = None fill = torch.ops.aten.fill.Scalar(view_copy, 1); view_copy = None view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]); fill = None - view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [25]) + view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [25]); view_copy_2 = None add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1) return (view_copy_1, add_1) """, @@ -1787,11 +1787,11 @@ def forward(self, arg0_1): def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None - resize = 
torch.ops.aten.resize_.default(add, [5, 5]) + resize = torch.ops.aten.resize_.default(add, [5, 5]); resize = None view = torch.ops.aten.view.default(add, [25]); add = None - fill = torch.ops.aten.fill_.Scalar(view, 1) + fill = torch.ops.aten.fill_.Scalar(view, 1); fill = None view_1 = torch.ops.aten.view.default(view, [5, 5]); view = None - view_2 = torch.ops.aten.view.default(view_1, [25]) + view_2 = torch.ops.aten.view.default(view_1, [25]); view_2 = None add_1 = torch.ops.aten.add.Tensor(view_1, 1) return (view_1, add_1) """, @@ -1883,7 +1883,7 @@ def forward(self, arg0_1): select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5) fill = torch.ops.aten.fill.Scalar(select_copy, 1); select_copy = None select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5); zeros = fill = None - select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5) + select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5); select_copy_1 = None return select_scatter """, ) # noqa: B950 @@ -1900,8 +1900,8 @@ def forward(self, arg0_1): def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False) select = torch.ops.aten.select.int(zeros, 0, 5) - fill = torch.ops.aten.fill_.Scalar(select, 1); select = None - select_1 = torch.ops.aten.select.int(zeros, 0, 5) + fill = torch.ops.aten.fill_.Scalar(select, 1); select = fill = None + select_1 = torch.ops.aten.select.int(zeros, 0, 5); select_1 = None return zeros """, ) @@ -1943,30 +1943,30 @@ def forward(self, arg0_1, arg1_1, arg2_1): repeat = torch.ops.aten.repeat.default(arg1_1, [20]) repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None - empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) + empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None getitem = _native_batch_norm_legit_functional[0] - getitem_1 = _native_batch_norm_legit_functional[1] - getitem_2 = _native_batch_norm_legit_functional[2] + getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None + getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None getitem_3 = _native_batch_norm_legit_functional[3] getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None alias_copy = torch.ops.aten.alias_copy.default(arg1_1) - view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]) + view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); view_copy_1 = None view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); getitem_3 = None mean = torch.ops.aten.mean.dim(view_copy_2, [0]); view_copy_2 = None copy = torch.ops.aten.copy.default(alias_copy, mean); alias_copy = mean = None alias_copy_1 = torch.ops.aten.alias_copy.default(copy); copy = None - alias_copy_2 = torch.ops.aten.alias_copy.default(alias_copy_1) + alias_copy_2 = torch.ops.aten.alias_copy.default(alias_copy_1); alias_copy_2 = None alias_copy_3 = torch.ops.aten.alias_copy.default(arg2_1) - view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]) + view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); view_copy_3 = None 
view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); getitem_4 = None mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]); view_copy_4 = None copy_1 = torch.ops.aten.copy.default(alias_copy_3, mean_1); alias_copy_3 = mean_1 = None alias_copy_4 = torch.ops.aten.alias_copy.default(copy_1); copy_1 = None - alias_copy_5 = torch.ops.aten.alias_copy.default(alias_copy_4) + alias_copy_5 = torch.ops.aten.alias_copy.default(alias_copy_4); alias_copy_5 = None view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]); getitem = None - copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1); arg1_1 = alias_copy_1 = None - copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4); arg2_1 = alias_copy_4 = None + copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1); arg1_1 = alias_copy_1 = copy_ = None + copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4); arg2_1 = alias_copy_4 = copy__1 = None return view_copy_5 """, # noqa: B950 ) @@ -1989,30 +1989,30 @@ def forward(self, arg0_1, arg1_1, arg2_1): repeat = torch.ops.aten.repeat.default(arg1_1, [20]) repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None - empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) + empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None getitem = _native_batch_norm_legit_functional[0] - getitem_1 = _native_batch_norm_legit_functional[1] - getitem_2 = _native_batch_norm_legit_functional[2] + getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None + getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None getitem_3 = _native_batch_norm_legit_functional[3] getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None alias = torch.ops.aten.alias.default(arg1_1) - view_1 = torch.ops.aten.view.default(getitem_3, [20, 100]) + view_1 = torch.ops.aten.view.default(getitem_3, [20, 100]); view_1 = None view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]); getitem_3 = None mean = torch.ops.aten.mean.dim(view_2, [0]); view_2 = None copy = torch.ops.aten.copy.default(alias, mean); alias = mean = None alias_1 = torch.ops.aten.alias.default(copy); copy = None - alias_2 = torch.ops.aten.alias.default(alias_1) + alias_2 = torch.ops.aten.alias.default(alias_1); alias_2 = None alias_3 = torch.ops.aten.alias.default(arg2_1) - view_3 = torch.ops.aten.view.default(getitem_4, [20, 100]) + view_3 = torch.ops.aten.view.default(getitem_4, [20, 100]); view_3 = None view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]); getitem_4 = None mean_1 = torch.ops.aten.mean.dim(view_4, [0]); view_4 = None copy_1 = torch.ops.aten.copy.default(alias_3, mean_1); alias_3 = mean_1 = None alias_4 = torch.ops.aten.alias.default(copy_1); copy_1 = None - alias_5 = torch.ops.aten.alias.default(alias_4) + alias_5 = torch.ops.aten.alias.default(alias_4); alias_5 = None view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]); getitem = None - copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1); arg1_1 = alias_1 = None - copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4); arg2_1 = alias_4 = None + copy_ = 
torch.ops.aten.copy_.default(arg1_1, alias_1); arg1_1 = alias_1 = copy_ = None + copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4); arg2_1 = alias_4 = copy__1 = None return view_5 """, # noqa: B950 ) @@ -2052,15 +2052,15 @@ def forward(self, arg0_1, arg1_1, arg2_1): def forward(self, arg0_1, arg1_1, arg2_1): - empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) + empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None getitem = _native_batch_norm_legit_functional[0] - getitem_1 = _native_batch_norm_legit_functional[1] - getitem_2 = _native_batch_norm_legit_functional[2] + getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None + getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None getitem_3 = _native_batch_norm_legit_functional[3] getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None - copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None - copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None + copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = copy_ = None + copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = copy__1 = None return getitem """, # noqa: B950 ) @@ -2080,15 +2080,15 @@ def forward(self, arg0_1, arg1_1, arg2_1): def forward(self, arg0_1, arg1_1, arg2_1): - empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) + empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')); empty = None _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None getitem = _native_batch_norm_legit_functional[0] - getitem_1 = _native_batch_norm_legit_functional[1] - getitem_2 = _native_batch_norm_legit_functional[2] + getitem_1 = _native_batch_norm_legit_functional[1]; getitem_1 = None + getitem_2 = _native_batch_norm_legit_functional[2]; getitem_2 = None getitem_3 = _native_batch_norm_legit_functional[3] getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None - copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None - copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None + copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = copy_ = None + copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = copy__1 = None return getitem """, # noqa: B950 ) @@ -2129,9 +2129,9 @@ def forward(self, arg0_1, arg1_1, arg2_1): fx_g.code.strip(), """\ def forward(self, x_1): - view = torch.ops.aten.view.default(x_1, [-1]) + view = torch.ops.aten.view.default(x_1, [-1]); view = None mul = torch.ops.aten.mul.Tensor(x_1, 2); x_1 = None - view_1 = torch.ops.aten.view.default(mul, [-1]) + view_1 = torch.ops.aten.view.default(mul, [-1]); view_1 = None view_2 = torch.ops.aten.view.default(mul, [-1]); mul = None add = torch.ops.aten.add.Tensor(view_2, 1); view_2 = None return add""", diff --git a/test/test_fx.py b/test/test_fx.py index 
a58abb906d89..34426b0bd45a 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -3851,6 +3851,31 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
         self.assertIs(next(iter(a.users.keys())), output_node)
         m.graph.lint()
 
+    def test_delete_unused_values(self):
+        from torch.fx.experimental.proxy_tensor import make_fx
+
+        # disable mutable checking temporarily
+        orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
+        torch.fx.proxy.TracerBase.check_mutable_operations = False
+
+        def fn(a, b, c, d):
+            x = a + b
+            y = c + d
+            y.copy_(x)
+            x = torch.relu(x)
+            return x
+
+        a, b, c, d = (torch.randn(2, 4, requires_grad=False) for _ in range(4))
+        fx_fn = make_fx(fn)(a, b, c, d)
+        print(fx_fn)
+
+        fx_fn.graph.eliminate_dead_code()
+        py_code = fx_fn.recompile()
+        self.assertTrue("copy_ = torch.ops.aten.copy_.default" in py_code.src)
+        self.assertTrue("copy_ = None" in py_code.src)
+
+        # restore the mutable checking flag
+        torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
 
 def run_getitem_target():
     from torch.fx._symbolic_trace import _wrapped_methods_to_patch
diff --git a/test/test_fx_reinplace_pass.py b/test/test_fx_reinplace_pass.py
index 3fe1d59d17b7..ce37162eb8f0 100644
--- a/test/test_fx_reinplace_pass.py
+++ b/test/test_fx_reinplace_pass.py
@@ -31,7 +31,7 @@ class TestReinplacePass(TestCase):
 def forward(self, x_1):
     clone = torch.ops.aten.clone.default(x_1); x_1 = None
-    add = torch.ops.aten.add_.Tensor(clone, 1)
+    add = torch.ops.aten.add_.Tensor(clone, 1); add = None
     return clone
     """)
@@ -58,8 +58,8 @@ def forward(self, x_1):
 def forward(self, x_1):
     clone = torch.ops.aten.clone.default(x_1); x_1 = None
     view = torch.ops.aten.view.default(clone, [-1])
-    add = torch.ops.aten.add.Tensor(clone, 1); clone = None
-    add_1 = torch.ops.aten.add_.Tensor(view, 1)
+    add = torch.ops.aten.add.Tensor(clone, 1); clone = add = None
+    add_1 = torch.ops.aten.add_.Tensor(view, 1); add_1 = None
     return view
     """)
@@ -144,20 +144,20 @@ def forward(self, a__1):
 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1); a__1 = None
-    view = torch.ops.aten.view.default(clone, [-1])
+    view = torch.ops.aten.view.default(clone, [-1]); view = None
     view_1 = torch.ops.aten.view.default(clone, [-1])
     select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
     view_2 = torch.ops.aten.view.default(select, [-1]); select = None
-    add = torch.ops.aten.add_.Tensor(view_2, 1)
+    add = torch.ops.aten.add_.Tensor(view_2, 1); add = None
     view_3 = torch.ops.aten.view.default(clone, [-1]); clone = None
-    select_1 = torch.ops.aten.select.int(view_3, 0, 0)
-    view_4 = torch.ops.aten.view.default(view_2, []); view_2 = None
+    select_1 = torch.ops.aten.select.int(view_3, 0, 0); select_1 = None
+    view_4 = torch.ops.aten.view.default(view_2, []); view_2 = view_4 = None
     view_5 = torch.ops.aten.view.default(view_3, [4]); view_3 = None
     view_6 = torch.ops.aten.view.default(view_5, [-1])
    select_2 = torch.ops.aten.select.int(view_6, 0, 0); view_6 = None
-    view_7 = torch.ops.aten.view.default(select_2, [-1]); select_2 = None
+    view_7 = torch.ops.aten.view.default(select_2, [-1]); select_2 = view_7 = None
     view_8 = torch.ops.aten.view.default(view_5, [-1])
-    add_1 = torch.ops.aten.add_.Tensor(view_5, view_8); view_8 = None
+    add_1 = torch.ops.aten.add_.Tensor(view_5, view_8); view_8 = add_1 = None
     return view_5
     """)
@@ -187,12 +187,12 @@ def forward(self, a__1):
     slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
     select = 
torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None select_1 = torch.ops.aten.select.int(select, 0, 1); select = None - add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = None + add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807) - select_2 = torch.ops.aten.select.int(slice_2, 1, 1); slice_2 = None + select_2 = torch.ops.aten.select.int(slice_2, 1, 1); slice_2 = select_2 = None slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807) select_3 = torch.ops.aten.select.int(slice_3, 1, 1); slice_3 = None - select_4 = torch.ops.aten.select.int(select_3, 0, 1); select_3 = None + select_4 = torch.ops.aten.select.int(select_3, 0, 1); select_3 = select_4 = None return clone """) @@ -227,7 +227,7 @@ def forward(self, a__1): slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807) select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None select_1 = torch.ops.aten.select.int(select, 0, 1); select = None - add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = None + add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None return as_strided """) @@ -266,7 +266,7 @@ def forward(self, a__1): add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None select_int = torch.ops.aten.select.int(as_strided, 0, 0) - copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = None + copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = copy__default = None return as_strided """) # noqa: B950 @@ -299,7 +299,7 @@ def forward(self, a__1): add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0); clone = None select_int = torch.ops.aten.select.int(as_strided, 0, 1) - copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = None + copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = copy__default = None return as_strided """) # noqa: B950 @@ -326,7 +326,7 @@ def forward(self, a__1): def forward(self): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros) - add = torch.ops.aten.add_.Tensor(diagonal, 1); diagonal = None + add = torch.ops.aten.add_.Tensor(diagonal, 1); diagonal = add = None return [zeros] """) @@ -351,10 +351,10 @@ def forward(self): ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False) slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807) slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None - copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = None - slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807) + copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = copy = None + slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807); slice_3 = None slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807) - slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807); slice_4 = None + slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807); slice_4 = slice_5 = None return zeros """) diff --git a/test/test_proxy_tensor.py 
b/test/test_proxy_tensor.py index d0fd7ae7c096..f9fe889d1d40 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1008,7 +1008,7 @@ def forward(self, x_1, y_1): self.assertExpectedInline(r, """\ def forward(self, x_1, y_1): sym_size_int = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None - resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]); x_1 = sym_size_int = None + resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]); x_1 = sym_size_int = resize_ = None return None""") def test_broadcast_shapes(self): @@ -1303,7 +1303,7 @@ def forward(self, x_1): sym_size_int = torch.ops.aten.sym_size.int(x_1, 0) scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); sym_size_int = None select = torch.ops.aten.select.int(x_1, 0, 0) - copy_ = torch.ops.aten.copy_.default(select, scalar_tensor); select = scalar_tensor = None + copy_ = torch.ops.aten.copy_.default(select, scalar_tensor); select = scalar_tensor = copy_ = None return x_1""" # noqa: B950 ) @@ -1321,7 +1321,7 @@ def forward(self, gravity_1, mask_1): index = torch.ops.aten.index.Tensor(select, [mask_1]); select = None mul = torch.ops.aten.mul.Tensor(index, -1); index = None select_1 = torch.ops.aten.select.int(gravity_1, 1, 0); gravity_1 = None - index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul); select_1 = mask_1 = mul = None + index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul); select_1 = mask_1 = mul = index_put_ = None return None""") def test_reflect_r_over_x(self): @@ -1345,7 +1345,7 @@ def forward(self, crop_camera_1, mask_1): lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None select = torch.ops.aten.select.int(eye, 0, 0) select_1 = torch.ops.aten.select.int(select, 0, 0); select = None - copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = None + copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = copy_ = None sym_size_int = torch.ops.aten.sym_size.int(index, 0) expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3]) view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]); expand = None @@ -1359,7 +1359,7 @@ def forward(self, crop_camera_1, mask_1): view_3 = torch.ops.aten.view.default(view_2, [mul, 3]); view_2 = mul = None mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None - index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = None + index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = index_put_ = None return None""") # noqa: B950 def test_unbacked_slice(self): @@ -1412,7 +1412,7 @@ def forward(self, images_1, handedness_1, valid_1): eq = torch.ops.aten.eq.Scalar(index_1, 1); index_1 = None index_2 = torch.ops.aten.index.Tensor(index, [eq]) flip = torch.ops.aten.flip.default(index_2, [-1]); index_2 = None - index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip); index = eq = flip = None + index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip); index = eq = flip = index_put_ = None return None""") def test_neg_shape(self): @@ -1481,7 +1481,7 @@ def forward(self, x_1, y_1): self.assertExpectedInline(r, """\ def forward(self, x_1, y_1): _local_scalar_dense = 
torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
-    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
+    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = zeros = None
     add = torch.ops.aten.add.Tensor(y_1, 2); y_1 = None
     return add""")  # noqa: B950
@@ -1566,9 +1566,9 @@ def forward(self, lengths_1, values_1):
     _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
     select_2 = torch.ops.aten.select.int(lengths_1, 0, 2); lengths_1 = None
     _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2); select_2 = None
-    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense)
-    sym_constrain_range_for_size_1 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_1)
-    sym_constrain_range_for_size_2 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_2)
+    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense); sym_constrain_range_for_size = None
+    sym_constrain_range_for_size_1 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_1); sym_constrain_range_for_size_1 = None
+    sym_constrain_range_for_size_2 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_2); sym_constrain_range_for_size_2 = None
     split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]); values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
     getitem = split_with_sizes[0]
     getitem_1 = split_with_sizes[1]
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 60c864b43b79..0c5687bc376f 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -528,6 +528,13 @@ class CodeGen:
                 body.append('\n')
                 return
             nodes_to_delete = user_to_last_uses.get(user, [])
+
+            if len(user.users.keys()) == 0:
+                # This node is not used by any other node, but it is also not
+                # removed by DCE because it has side effects. We want to free
+                # its outputs right after it executes to save memory.
+                nodes_to_delete.append(user)
+
             if len(nodes_to_delete):
                 to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                 body.append(f'; {dim(to_delete_str)}\n')
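
For context, a minimal sketch (not part of the patch) of the behavior this change produces, modeled on the new test above. It traces a function whose in-place aten.copy_ survives dead-code elimination because it has side effects, yet has no users; with this patch, the generated Python rebinds that result to None immediately after the call so the tensor can be freed earlier. The helper function fn, its tensor shapes, and the variable names in the comments are illustrative only.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Mirror the test above: temporarily disable the mutable-operation check
# so the in-place copy_ can be traced.
orig_flag = torch.fx.proxy.TracerBase.check_mutable_operations
torch.fx.proxy.TracerBase.check_mutable_operations = False
try:
    def fn(a, b):
        x = a + b
        b.copy_(x)            # side-effecting op: kept by DCE, but its result is unused
        return torch.relu(x)

    # Shapes are arbitrary.
    gm = make_fx(fn)(torch.randn(2, 4), torch.randn(2, 4))
    gm.graph.eliminate_dead_code()
    gm.recompile()

    # With this patch the generated code contains a line shaped like
    #   copy_ = torch.ops.aten.copy_.default(arg1_1, add); ... copy_ = None
    # i.e. the unused output is dropped right after the call instead of
    # staying alive until the end of forward().
    print(gm.code)
finally:
    torch.fx.proxy.TracerBase.check_mutable_operations = orig_flag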