Fix double dispatch to Python for detach (#163671)

This fixes #71725.

Differential Revision: [D83857880](https://our.internmc.facebook.com/intern/diff/D83857880)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163671
Approved by: https://github.com/ezyang, https://github.com/albanD
Authored by Scott Wolchok on 2025-10-12 12:55:37 -07:00; committed by PyTorch MergeBot
parent 815d641599
commit 331b7cc054
13 changed files with 60 additions and 110 deletions
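
Before the per-file hunks, a minimal sketch of the user-visible change (not part of the commit; it reuses the LoggingTensor helpers from torch.testing._internal.logging_tensor that the updated test_python_dispatch test below exercises, and the x.detach() trigger is assumed since that hunk elides it):

import torch
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input

with capture_logs() as logs:
    x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
    log_input("x", x)
    x.detach()  # one user-level detach call (assumed trigger)

print("\n".join(logs))
# Before this commit the log contained aten.detach.default twice for that
# single call; with the fix it appears once:
#   $0: f32[1] = input('x')
#   $1: f32[1] = torch._ops.aten.detach.default($0)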

View File

@@ -239,9 +239,7 @@ class DTensorExportTest(TestCase):
"view_9",
"t_15",
"detach",
"detach_1",
"detach_6",
"detach_7",
"detach_3",
"threshold_backward_1",
"t_16",
"mm_6",
@@ -259,10 +257,8 @@ class DTensorExportTest(TestCase):
"sum_1",
"view_7",
"t_7",
"detach_1",
"detach_2",
"detach_3",
"detach_4",
"detach_5",
"threshold_backward",
"mm_2",
"t_9",

View File

@@ -921,7 +921,6 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
1|aten._native_batch_norm_legit_functional.default|batch_norm|
2|aten.relu.default|relu|
2|aten.detach.default|relu|
2|aten.detach.default|relu|
3|aten.add.Tensor|add|
4|aten.view.default|flatten|
5|aten.view.default|linear|
@@ -948,7 +947,6 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
5|aten.view.default||linear
4|aten.view.default||flatten
2|aten.detach.default||relu
2|aten.detach.default||relu
2|aten.threshold_backward.default||relu
1|aten.native_batch_norm_backward.default||batch_norm
0|aten.convolution_backward.default||conv2d

View File

@@ -216,18 +216,16 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
('call_function', 'getitem', {'compile_inductor': 0})
('call_function', 'getitem_1', {'compile_inductor': 0})
('call_function', 'detach_1', {'compile_inductor': 0})
('call_function', 'detach_4', {'compile_inductor': 0})
('call_function', 'detach_5', {'compile_inductor': 0})""", # noqa: B950
('call_function', 'detach_3', {'compile_inductor': 0})""", # noqa: B950
)
self.assertExpectedInline(
str(bw_metadata),
"""\
('placeholder', 'getitem', {'compile_inductor': 0})
('placeholder', 'detach_5', {'compile_inductor': 0})
('placeholder', 'detach_3', {'compile_inductor': 0})
('call_function', 'zeros', {'compile_inductor': 0})
('call_function', 'detach', {'compile_inductor': 0})
('call_function', 'detach_2', {'compile_inductor': 0})
('call_function', 'detach_3', {'compile_inductor': 0})
('get_attr', 'fw_graph0', {'compile_inductor': 0})
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('get_attr', 'joint_graph0', {'compile_inductor': 0})

View File

@@ -684,11 +684,11 @@ class StructuredTraceTest(TestCase):
{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 29, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 28, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 28, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_storage": {"id": 17, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 30, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 30, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 29, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}

View File

@@ -45,11 +45,9 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
_softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
alias = torch.ops.aten.alias.default(_softmax)
alias_1 = torch.ops.aten.alias.default(alias); alias = None
clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
_log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
alias_2 = torch.ops.aten.alias.default(_log_softmax)
alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
alias_1 = torch.ops.aten.alias.default(_log_softmax)
mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
@@ -59,17 +57,15 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None
exp = torch.ops.aten.exp.default(alias_5); alias_5 = None
alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None
exp = torch.ops.aten.exp.default(alias_2); alias_2 = None
sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None
alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None
alias_3 = torch.ops.aten.alias.default(alias); alias = None
mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None
sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None
mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None
sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
@@ -91,11 +87,9 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
_softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
alias = torch.ops.aten.alias.default(_softmax)
alias_1 = torch.ops.aten.alias.default(alias); alias = None
clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
_log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
alias_2 = torch.ops.aten.alias.default(_log_softmax)
alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
alias_1 = torch.ops.aten.alias.default(_log_softmax)
mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
@@ -105,17 +99,15 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None
exp = torch.ops.aten.exp.default(alias_5); alias_5 = None
alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None
exp = torch.ops.aten.exp.default(alias_2); alias_2 = None
sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None
alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None
alias_3 = torch.ops.aten.alias.default(alias); alias = None
mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None
sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None
mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None
sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])

View File

@@ -1229,9 +1229,7 @@ def forward(self, primals, tangents):
t = torch.ops.aten.t.default(primals_1); primals_1 = None
addmm = torch.ops.aten.addmm.default(primals_2, primals_5, t); primals_2 = None
relu = torch.ops.aten.relu.default(addmm); addmm = None
detach_9 = torch.ops.aten.detach.default(relu)
detach_10 = torch.ops.aten.detach.default(detach_9); detach_9 = None
detach_11 = torch.ops.aten.detach.default(detach_10); detach_10 = None
detach_3 = torch.ops.aten.detach.default(relu)
t_1 = torch.ops.aten.t.default(primals_3); primals_3 = None
addmm_1 = torch.ops.aten.addmm.default(primals_4, relu, t_1); primals_4 = None
t_2 = torch.ops.aten.t.default(t_1); t_1 = None
@@ -1242,9 +1240,8 @@ def forward(self, primals, tangents):
sum_1 = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True); tangents_1 = None
view = torch.ops.aten.view.default(sum_1, [128]); sum_1 = None
t_5 = torch.ops.aten.t.default(t_4); t_4 = None
detach_18 = torch.ops.aten.detach.default(detach_11); detach_11 = None
detach_19 = torch.ops.aten.detach.default(detach_18); detach_18 = None
threshold_backward = torch.ops.aten.threshold_backward.default(mm, detach_19, 0); mm = detach_19 = None
detach_6 = torch.ops.aten.detach.default(detach_3); detach_3 = None
threshold_backward = torch.ops.aten.threshold_backward.default(mm, detach_6, 0); mm = detach_6 = None
t_6 = torch.ops.aten.t.default(t); t = None
mm_2 = torch.ops.aten.mm.default(threshold_backward, t_6); t_6 = None
t_7 = torch.ops.aten.t.default(threshold_backward)
@@ -10320,13 +10317,9 @@ graph():
%x : [num_users=2] = placeholder[target=x]
%ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
%detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {})
%detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach,), kwargs = {})
%detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {})
%clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
%detach_3 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
%detach_4 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_3,), kwargs = {})
%detach_5 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_4,), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach_2, %detach_5), kwargs = {})
%detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_1), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
return (mul_1,)""",

View File

@@ -237,9 +237,7 @@ class inner_f(torch.nn.Module):
where: "f32[2, 3, 4, 4]" = torch.ops.prims.where.default(le, 0.0, add_4); le = add_4 = None
view_of: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(where)
view_of_1: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of); view_of = None
view_of_2: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of_1); view_of_1 = None
view_of_3: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of_2); view_of_2 = None
le_1: "b8[2, 3, 4, 4]" = torch.ops.prims.le.default(view_of_3, 0.0); view_of_3 = None
le_1: "b8[2, 3, 4, 4]" = torch.ops.prims.le.default(view_of_1, 0.0); view_of_1 = None
where_1: "f32[2, 3, 4, 4]" = torch.ops.prims.where.default(le_1, 0.0, tangents_1); le_1 = tangents_1 = None
broadcast_in_dim_10: "f32[1, 3]" = torch.ops.prims.broadcast_in_dim.default(squeeze_2, [1, 3], [1]); squeeze_2 = None
broadcast_in_dim_11: "f32[1, 3, 1]" = torch.ops.prims.broadcast_in_dim.default(broadcast_in_dim_10, [1, 3, 1], [0, 1]); broadcast_in_dim_10 = None

View File

@@ -2278,9 +2278,7 @@ def forward(self, primals_1):
view = torch.ops.aten.view.default(mul, [-1])
select = torch.ops.aten.select.int(mul, 0, 0)
detach = torch.ops.aten.detach.default(select); select = None
detach_1 = torch.ops.aten.detach.default(detach); detach = None
detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None
return (view, mul, detach_2)""",
return (view, mul, detach)""",
)
def test_output_aliases_intermediate_inplace_view(self):
@@ -5138,23 +5136,12 @@ class <lambda>(torch.nn.Module):
relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None
detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); detach = None
detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu)
detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_2); detach_2 = None
detach_4: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_3); detach_3 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
detach_5: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
detach_6: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_5); detach_5 = None
detach_7: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_6); detach_6 = None
detach_8: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_7); detach_7 = None
detach_9: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_8); detach_8 = None
detach_10: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_9); detach_9 = None
detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
ones_like: "f32[]" = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
expand: "f32[1, 3, 3, 3]" = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None
detach_11: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_4); detach_4 = None
detach_12: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_11); detach_11 = None
detach_13: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_12); detach_12 = None
detach_14: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_13); detach_13 = None
threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_14, 0); expand = detach_14 = None
detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_3, 0); expand = detach_3 = None
native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None
getitem_5: "f32[1, 3, 3, 3]" = native_batch_norm_backward[0]
getitem_6: "f32[3]" = native_batch_norm_backward[1]
@@ -5163,7 +5150,7 @@ class <lambda>(torch.nn.Module):
getitem_8 = convolution_backward[0]; getitem_8 = None
getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1]
getitem_10: "f32[3]" = convolution_backward[2]; convolution_backward = None
return (getitem_3, getitem_4, add, sum_1, detach_10, getitem_9, getitem_10, getitem_6, getitem_7)
return (getitem_3, getitem_4, add, sum_1, detach_2, getitem_9, getitem_10, getitem_6, getitem_7)
""", # noqa: B950
)
@@ -5231,14 +5218,12 @@ class <lambda>(torch.nn.Module):
relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None
sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach); detach = None
detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
return (
getitem_3, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=4))
getitem_4, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=5))
add, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=6))
sum_1, # PlainAOTOutput(idx=0)
detach_2, # PlainAOTOutput(idx=1)
detach, # PlainAOTOutput(idx=1)
)
""", # noqa: B950
)

View File

@@ -1174,12 +1174,10 @@ class TestMemoryProfilerE2E(TestCase):
aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 20 (AUTOGRAD_DETAIL)
aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT)
aten::view 21 (GRADIENT) -> 21 (GRADIENT)
aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
aten::detach 21 (GRADIENT) -> ???
aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT)
aten::view 23 (GRADIENT) -> 23 (GRADIENT)
aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
aten::detach 23 (GRADIENT) -> ???""",
)
@@ -1227,12 +1225,10 @@ class TestMemoryProfilerE2E(TestCase):
aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT)
aten::view 21 (GRADIENT) -> 21 (GRADIENT)
aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT)
aten::view 23 (GRADIENT) -> 23 (GRADIENT)
aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
-- Optimizer --------------------------------------------------------------------------------------------
aten::add_.Tensor 3 (PARAMETER), 23 (GRADIENT) -> 3 (PARAMETER)
@@ -1277,10 +1273,8 @@ class TestMemoryProfilerE2E(TestCase):
aten::t 7 (GRADIENT) -> 7 (GRADIENT)
aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT)
aten::view 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> ???
aten::t 7 (GRADIENT) -> 7 (GRADIENT)
aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
aten::detach 7 (GRADIENT) -> ???""",
)
@@ -1318,18 +1312,14 @@ class TestMemoryProfilerE2E(TestCase):
aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT)
aten::view 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::t 7 (GRADIENT) -> 7 (GRADIENT)
aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
-- Optimizer --------------------------------------------------------------------------------------------
aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
aten::clone 7 (GRADIENT) -> 10 (OPTIMIZER_STATE)
aten::add_.Tensor 2 (PARAMETER), 10 (OPTIMIZER_STATE) -> 2 (PARAMETER)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::clone 9 (GRADIENT) -> 11 (OPTIMIZER_STATE)
aten::add_.Tensor 3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER)""",
)
@@ -1414,7 +1404,6 @@ class TestMemoryProfilerE2E(TestCase):
aten::t 7 (PARAMETER) -> 7 (PARAMETER)
aten::mm 25 (AUTOGRAD_DETAIL), 7 (PARAMETER) -> 27 (AUTOGRAD_DETAIL)
aten::t 26 (GRADIENT) -> 26 (GRADIENT)
aten::detach 26 (GRADIENT) -> 26 (GRADIENT)
aten::detach 26 (GRADIENT) -> ???
aten::detach 6 (ACTIVATION) -> 6 (ACTIVATION)
aten::threshold_backward 27 (AUTOGRAD_DETAIL), 6 (ACTIVATION) -> 28 (AUTOGRAD_DETAIL)
@@ -1423,10 +1412,8 @@ class TestMemoryProfilerE2E(TestCase):
aten::t 29 (GRADIENT) -> 29 (GRADIENT)
aten::sum.dim_IntList 28 (AUTOGRAD_DETAIL) -> 30 (GRADIENT)
aten::view 30 (GRADIENT) -> 30 (GRADIENT)
aten::detach 30 (GRADIENT) -> 30 (GRADIENT)
aten::detach 30 (GRADIENT) -> ???
aten::t 29 (GRADIENT) -> 29 (GRADIENT)
aten::detach 29 (GRADIENT) -> 29 (GRADIENT)
aten::detach 29 (GRADIENT) -> ???""",
)

View File

@@ -5050,7 +5050,6 @@ Running aten.expand.default from within SumBackward0
Running aten.div.Tensor from within DivBackward0
Running aten.mul.Tensor from within MulBackward0
Running aten.detach.default from within AccumulateGrad
Running aten.detach.default from within AccumulateGrad
Done""",
)
@@ -7323,9 +7322,7 @@ for shape in [(1,), ()]:
lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn
)
out.backward()
self.assertEqual(
verbose_mode.operators, ["exp.default", "detach.default", "detach.default"]
)
self.assertEqual(verbose_mode.operators, ["exp.default", "detach.default"])
with self.assertRaisesRegex(
Exception, "only supported when use_reentrant=False"

View File

@@ -850,7 +850,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
lambda: A(torch.zeros(1)).detach(),
)
def test_detach_appears_twice_when_called_once(self) -> None:
def test_detach_appears_once_when_called_once(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
log_input("x", x)
@@ -863,8 +862,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.detach.default($0)
$2: f32[1] = torch._ops.aten.detach.default($1)""",
$1: f32[1] = torch._ops.aten.detach.default($0)""",
)
def test_storage(self) -> None:

View File

@@ -453,20 +453,18 @@ static Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
return at::_ops::detach::redispatch(
ks & c10::after_ADInplaceOrView_keyset, self);
})();
// NB: we can't make detach() a normal view operator because the codegen
// generates allow_tensor_metadata_change = True for them. In the future we
// should have an option for this in the codegen.
auto result = as_view(
/* base */ self,
/* output */ out,
/* is_bw_differentiable */ false,
/* is_fw_differentiable */ false,
/* view_func */ nullptr,
/* rev_view_func */ nullptr,
/* creation_meta */ CreationMeta::DEFAULT,
/*allow_tensor_metadata_change=*/false);
return result;
// NB: we can't make detach() a normal view operator because the
// codegen generates allow_tensor_metadata_change = True (and leaves
// is_fresh_tensor to the default setting of False) for them. In the
// future we should have an option for this in the codegen.
if (self.is_inference()) {
return out;
}
return ::torch::autograd::make_variable_non_differentiable_view(
self,
out,
/* allow_tensor_metadata_change */ false,
/* is_fresh_tensor */ true);
}
static Tensor _fw_primal(

View File

@@ -858,11 +858,20 @@ inline Variable make_variable_non_differentiable_view(
inline Variable make_variable_non_differentiable_view(
const Variable& base,
const at::Tensor& data,
bool allow_tensor_metadata_change = true) {
bool allow_tensor_metadata_change = true,
bool is_fresh_tensor = false) {
if (data.defined()) {
// Currently all of non-differentiable view ops(detach/_indices/_values)
// share the same TensorImpl as their base Tensor. Thus a new TensorImpl
// allocation here is required.
// If we already allocated a new tensor, no need to
// shallow_copy_and_detach here. (See #163671 history; we tried to
// fan out to _indices and _values and ran into a SparseTensorImpl
// can of worms.)
if (is_fresh_tensor) {
auto* data_impl = data.unsafeGetTensorImpl();
data_impl->set_version_counter(impl::version_counter(base));
data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
data_impl->set_autograd_meta(nullptr);
return data;
}
auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
/*version_counter=*/impl::version_counter(base),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
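
For orientation (an illustration, not part of the diff): the new is_fresh_tensor flag lets a caller that has just allocated its output skip the shallow_copy_and_detach allocation above and reuse that TensorImpl in place. A hedged sketch of the two call patterns:

// detach() (see the detach hunk above): the redispatch has already produced a
// brand-new Tensor `out`, so it opts in; `out`'s impl is adjusted in place
// (version counter shared with the base, tensor metadata changes disallowed,
// autograd meta cleared):
//   return make_variable_non_differentiable_view(
//       self, out,
//       /* allow_tensor_metadata_change */ false,
//       /* is_fresh_tensor */ true);
//
// _indices()/_values() keep the default is_fresh_tensor = false and still go
// through shallow_copy_and_detach(), sidestepping the SparseTensorImpl issues
// noted in the comment above.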