diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index 9045026a317f..9b8ee438cfad 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -885,6 +885,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn 1|aten._native_batch_norm_legit_functional.default|l__self___bn1| 2|aten.relu.default|l__self___relu1| 2|aten.detach.default|l__self___relu1| +2|aten.detach.default|l__self___relu1| 3|aten.add.Tensor|add| 4|aten.view.default|flatten| 5|aten.view.default|l__self___fc1| @@ -911,6 +912,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn 5|aten.view.default||l__self___fc1 4|aten.view.default||flatten 2|aten.detach.default||l__self___relu1 +2|aten.detach.default||l__self___relu1 2|aten.threshold_backward.default||l__self___relu1 1|aten.native_batch_norm_backward.default||l__self___bn1 0|aten.convolution_backward.default||l__self___conv1 diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 6e9379be092e..501b08e65901 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -45,9 +45,11 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None alias = torch.ops.aten.alias.default(_softmax) + alias_1 = torch.ops.aten.alias.default(alias); alias = None clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None - alias_1 = torch.ops.aten.alias.default(_log_softmax) + alias_2 = torch.ops.aten.alias.default(_log_softmax) + alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None neg = torch.ops.aten.neg.default(sum_1); sum_1 = None @@ -57,15 +59,17 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None - alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None - exp = torch.ops.aten.exp.default(alias_2); alias_2 = None + alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None + alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None + exp = torch.ops.aten.exp.default(alias_5); alias_5 = None sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True) mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None - alias_3 = torch.ops.aten.alias.default(alias); alias = None - mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None + alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None + alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None + mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True) - mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None + mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None permute_1 = torch.ops.aten.permute.default(view_2, [1, 0]) @@ -87,9 +91,11 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None alias = torch.ops.aten.alias.default(_softmax) + alias_1 = torch.ops.aten.alias.default(alias); alias = None clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None - alias_1 = torch.ops.aten.alias.default(_log_softmax) + alias_2 = torch.ops.aten.alias.default(_log_softmax) + alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None neg = torch.ops.aten.neg.default(sum_1); sum_1 = None @@ -99,15 +105,17 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None - alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None - exp = torch.ops.aten.exp.default(alias_2); alias_2 = None + alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None + alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None + exp = torch.ops.aten.exp.default(alias_5); alias_5 = None sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True) mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None - alias_3 = torch.ops.aten.alias.default(alias); alias = None - mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None + alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None + alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None + mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True) - mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None + mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None permute_1 = torch.ops.aten.permute.default(view_2, [1, 0]) diff --git a/test/export/test_export.py b/test/export/test_export.py index 566c7efe425a..f7b44c24d82e 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -10373,9 +10373,13 @@ graph(): %x : [num_users=2] = placeholder[target=x] %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False}) %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {}) + %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach,), kwargs = {}) + %detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {}) %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {}) - %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {}) - %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_1), kwargs = {}) + %detach_3 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {}) + %detach_4 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_3,), kwargs = {}) + %detach_5 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_4,), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach_2, %detach_5), kwargs = {}) %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {}) %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) return (mul_1,)""", diff --git a/test/functorch/test_aot_joint_with_descriptors.py b/test/functorch/test_aot_joint_with_descriptors.py index 44a562d9ae9a..6b80af961e06 100644 --- a/test/functorch/test_aot_joint_with_descriptors.py +++ b/test/functorch/test_aot_joint_with_descriptors.py @@ -214,7 +214,9 @@ class inner_f(torch.nn.Module): where: "f32[2, 3, 4, 4]" = torch.ops.prims.where.default(le, 0.0, add_4); le = add_4 = None view_of: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(where) view_of_1: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of); view_of = None - le_1: "b8[2, 3, 4, 4]" = torch.ops.prims.le.default(view_of_1, 0.0); view_of_1 = None + view_of_2: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of_1); view_of_1 = None + view_of_3: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of_2); view_of_2 = None + le_1: "b8[2, 3, 4, 4]" = torch.ops.prims.le.default(view_of_3, 0.0); view_of_3 = None where_1: "f32[2, 3, 4, 4]" = torch.ops.prims.where.default(le_1, 0.0, tangents_1); le_1 = tangents_1 = None broadcast_in_dim_10: "f32[1, 3]" = torch.ops.prims.broadcast_in_dim.default(squeeze_2, [1, 3], [1]); squeeze_2 = None broadcast_in_dim_11: "f32[1, 3, 1]" = torch.ops.prims.broadcast_in_dim.default(broadcast_in_dim_10, [1, 3, 1], [0, 1]); broadcast_in_dim_10 = None diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index a7159a06d028..080002999964 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2278,7 +2278,9 @@ def forward(self, primals_1): view = torch.ops.aten.view.default(mul, [-1]) select = torch.ops.aten.select.int(mul, 0, 0) detach = torch.ops.aten.detach.default(select); select = None - return (view, mul, detach)""", + detach_1 = torch.ops.aten.detach.default(detach); detach = None + detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None + return (view, mul, detach_2)""", ) def test_output_aliases_intermediate_inplace_view(self): @@ -5136,12 +5138,23 @@ class (torch.nn.Module): relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); detach = None detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu) + detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None + detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_2); detach_2 = None + detach_4: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_3); detach_3 = None sum_1: "f32[]" = torch.ops.aten.sum.default(relu) - detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None + detach_5: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None + detach_6: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_5); detach_5 = None + detach_7: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_6); detach_6 = None + detach_8: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_7); detach_7 = None + detach_9: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_8); detach_8 = None + detach_10: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_9); detach_9 = None ones_like: "f32[]" = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format) expand: "f32[1, 3, 3, 3]" = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None - detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None - threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_3, 0); expand = detach_3 = None + detach_11: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_4); detach_4 = None + detach_12: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_11); detach_11 = None + detach_13: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_12); detach_12 = None + detach_14: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_13); detach_13 = None + threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_14, 0); expand = detach_14 = None native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None getitem_5: "f32[1, 3, 3, 3]" = native_batch_norm_backward[0] getitem_6: "f32[3]" = native_batch_norm_backward[1] @@ -5150,7 +5163,7 @@ class (torch.nn.Module): getitem_8 = convolution_backward[0]; getitem_8 = None getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1] getitem_10: "f32[3]" = convolution_backward[2]; convolution_backward = None - return (getitem_3, getitem_4, add, sum_1, detach_2, getitem_9, getitem_10, getitem_6, getitem_7) + return (getitem_3, getitem_4, add, sum_1, detach_10, getitem_9, getitem_10, getitem_6, getitem_7) """, # noqa: B950 ) @@ -5218,12 +5231,14 @@ class (torch.nn.Module): relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None sum_1: "f32[]" = torch.ops.aten.sum.default(relu) detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None + detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach); detach = None + detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None return ( getitem_3, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=4)) getitem_4, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=5)) add, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=6)) sum_1, # PlainAOTOutput(idx=0) - detach, # PlainAOTOutput(idx=1) + detach_2, # PlainAOTOutput(idx=1) ) """, # noqa: B950 ) diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py index 9f630119363f..f9821d1bf3a2 100644 --- a/test/profiler/test_memory_profiler.py +++ b/test/profiler/test_memory_profiler.py @@ -1174,10 +1174,12 @@ class TestMemoryProfilerE2E(TestCase): aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 20 (AUTOGRAD_DETAIL) aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT) aten::view 21 (GRADIENT) -> 21 (GRADIENT) + aten::detach 21 (GRADIENT) -> 21 (GRADIENT) aten::detach 21 (GRADIENT) -> ??? aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL) aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT) aten::view 23 (GRADIENT) -> 23 (GRADIENT) + aten::detach 23 (GRADIENT) -> 23 (GRADIENT) aten::detach 23 (GRADIENT) -> ???""", ) @@ -1225,10 +1227,12 @@ class TestMemoryProfilerE2E(TestCase): aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT) aten::view 21 (GRADIENT) -> 21 (GRADIENT) aten::detach 21 (GRADIENT) -> 21 (GRADIENT) + aten::detach 21 (GRADIENT) -> 21 (GRADIENT) aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL) aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT) aten::view 23 (GRADIENT) -> 23 (GRADIENT) aten::detach 23 (GRADIENT) -> 23 (GRADIENT) + aten::detach 23 (GRADIENT) -> 23 (GRADIENT) -- Optimizer -------------------------------------------------------------------------------------------- aten::add_.Tensor 3 (PARAMETER), 23 (GRADIENT) -> 3 (PARAMETER) @@ -1273,8 +1277,10 @@ class TestMemoryProfilerE2E(TestCase): aten::t 7 (GRADIENT) -> 7 (GRADIENT) aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT) aten::view 9 (GRADIENT) -> 9 (GRADIENT) + aten::detach 9 (GRADIENT) -> 9 (GRADIENT) aten::detach 9 (GRADIENT) -> ??? aten::t 7 (GRADIENT) -> 7 (GRADIENT) + aten::detach 7 (GRADIENT) -> 7 (GRADIENT) aten::detach 7 (GRADIENT) -> ???""", ) @@ -1312,14 +1318,18 @@ class TestMemoryProfilerE2E(TestCase): aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT) aten::view 9 (GRADIENT) -> 9 (GRADIENT) aten::detach 9 (GRADIENT) -> 9 (GRADIENT) + aten::detach 9 (GRADIENT) -> 9 (GRADIENT) aten::t 7 (GRADIENT) -> 7 (GRADIENT) aten::detach 7 (GRADIENT) -> 7 (GRADIENT) + aten::detach 7 (GRADIENT) -> 7 (GRADIENT) -- Optimizer -------------------------------------------------------------------------------------------- aten::detach 7 (GRADIENT) -> 7 (GRADIENT) + aten::detach 7 (GRADIENT) -> 7 (GRADIENT) aten::clone 7 (GRADIENT) -> 10 (OPTIMIZER_STATE) aten::add_.Tensor 2 (PARAMETER), 10 (OPTIMIZER_STATE) -> 2 (PARAMETER) aten::detach 9 (GRADIENT) -> 9 (GRADIENT) + aten::detach 9 (GRADIENT) -> 9 (GRADIENT) aten::clone 9 (GRADIENT) -> 11 (OPTIMIZER_STATE) aten::add_.Tensor 3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER)""", ) @@ -1404,6 +1414,7 @@ class TestMemoryProfilerE2E(TestCase): aten::t 7 (PARAMETER) -> 7 (PARAMETER) aten::mm 25 (AUTOGRAD_DETAIL), 7 (PARAMETER) -> 27 (AUTOGRAD_DETAIL) aten::t 26 (GRADIENT) -> 26 (GRADIENT) + aten::detach 26 (GRADIENT) -> 26 (GRADIENT) aten::detach 26 (GRADIENT) -> ??? aten::detach 6 (ACTIVATION) -> 6 (ACTIVATION) aten::threshold_backward 27 (AUTOGRAD_DETAIL), 6 (ACTIVATION) -> 28 (AUTOGRAD_DETAIL) @@ -1412,8 +1423,10 @@ class TestMemoryProfilerE2E(TestCase): aten::t 29 (GRADIENT) -> 29 (GRADIENT) aten::sum.dim_IntList 28 (AUTOGRAD_DETAIL) -> 30 (GRADIENT) aten::view 30 (GRADIENT) -> 30 (GRADIENT) + aten::detach 30 (GRADIENT) -> 30 (GRADIENT) aten::detach 30 (GRADIENT) -> ??? aten::t 29 (GRADIENT) -> 29 (GRADIENT) + aten::detach 29 (GRADIENT) -> 29 (GRADIENT) aten::detach 29 (GRADIENT) -> ???""", ) diff --git a/test/test_autograd.py b/test/test_autograd.py index 7fda4f9df806..021659b81122 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -4926,6 +4926,7 @@ Running aten.expand.default from within SumBackward0 Running aten.div.Tensor from within DivBackward0 Running aten.mul.Tensor from within MulBackward0 Running aten.detach.default from within AccumulateGrad +Running aten.detach.default from within AccumulateGrad Done""", ) @@ -7198,7 +7199,9 @@ for shape in [(1,), ()]: lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn ) out.backward() - self.assertEqual(verbose_mode.operators, ["exp.default", "detach.default"]) + self.assertEqual( + verbose_mode.operators, ["exp.default", "detach.default", "detach.default"] + ) with self.assertRaisesRegex( Exception, "only supported when use_reentrant=False" diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 98fbabff11ef..07a92244cd73 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -850,7 +850,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""", lambda: A(torch.zeros(1)).detach(), ) - def test_detach_appears_once_when_called_once(self) -> None: + def test_detach_appears_twice_when_called_once(self) -> None: with capture_logs() as logs: x = LoggingTensor(torch.tensor([3.0]), requires_grad=True) log_input("x", x) @@ -863,7 +863,8 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""", "\n".join(logs), """\ $0: f32[1] = input('x') -$1: f32[1] = torch._ops.aten.detach.default($0)""", +$1: f32[1] = torch._ops.aten.detach.default($0) +$2: f32[1] = torch._ops.aten.detach.default($1)""", ) def test_storage(self) -> None: diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index c2c4dffee66e..e270df51221b 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -453,18 +453,20 @@ static Tensor detach(c10::DispatchKeySet ks, const Tensor& self) { return at::_ops::detach::redispatch( ks & c10::after_ADInplaceOrView_keyset, self); })(); - // NB: we can't make detach() a normal view operator because the - // codegen generates allow_tensor_metadata_change = True (and leaves - // is_fresh_tensor to the default setting of False) for them. In the - // future we should have an option for this in the codegen. - if (self.is_inference()) { - return out; - } - return ::torch::autograd::make_variable_non_differentiable_view( - self, - out, - /* allow_tensor_metadata_change */ false, - /* is_fresh_tensor */ true); + // NB: we can't make detach() a normal view operator because the codegen + // generates allow_tensor_metadata_change = True for them. In the future we + // should have an option for this in the codegen. + auto result = as_view( + /* base */ self, + /* output */ out, + /* is_bw_differentiable */ false, + /* is_fw_differentiable */ false, + /* view_func */ nullptr, + /* rev_view_func */ nullptr, + /* creation_meta */ CreationMeta::DEFAULT, + /*allow_tensor_metadata_change=*/false); + + return result; } static Tensor _fw_primal( diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index 3bd1f0aab6ae..dfffd3d97095 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -835,20 +835,11 @@ inline Variable make_variable_differentiable_view( inline Variable make_variable_non_differentiable_view( const Variable& base, const at::Tensor& data, - bool allow_tensor_metadata_change = true, - bool is_fresh_tensor = false) { + bool allow_tensor_metadata_change = true) { if (data.defined()) { - // If we already allocated a new tensor, no need to - // shallow_copy_and_detach here. (See #163671 history; we tried to - // fan out to _indices and _values and ran into a SparseTensorImpl - // can of worms.) - if (is_fresh_tensor) { - auto* data_impl = data.unsafeGetTensorImpl(); - data_impl->set_version_counter(impl::version_counter(base)); - data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); - data_impl->set_autograd_meta(nullptr); - return data; - } + // Currently all of non-differentiable view ops(detach/_indices/_values) + // share the same TensorImpl as their base Tensor. Thus a new TensorImpl + // allocation here is required. auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( /*version_counter=*/impl::version_counter(base), /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);