Fix double dispatch to Python for detach (#163671)

This fixes #71725.

Differential Revision: [D83857880](https://our.internmc.facebook.com/intern/diff/D83857880)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163671
Approved by: https://github.com/ezyang, https://github.com/albanD
Author: Scott Wolchok
Date: 2025-10-12 12:55:37 -07:00
Committed by: PyTorch MergeBot
Parent: 6bda3bb286
Commit: a3e3efe474
13 changed files with 60 additions and 110 deletions
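Before this change, a single `Tensor.detach()` call reached the Python dispatch layer twice: once from the detach kernel itself and a second time from the `shallow_copy_and_detach` path used to build the non-differentiable view, which re-dispatches when a `__torch_dispatch__` subclass or mode is active. That is why every expected output below loses one `detach`. A minimal sketch for observing the new behavior (the `OpLogger` mode is illustrative, not part of this PR):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpLogger(TorchDispatchMode):
    # Records every ATen op that reaches the Python dispatcher.
    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))
        return func(*args, **(kwargs or {}))

x = torch.tensor([3.0], requires_grad=True)
with OpLogger() as logger:
    x.detach()

# Previously aten.detach.default showed up twice here; with this fix it
# should show up once.
print(logger.ops)
```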

View File

@@ -239,9 +239,7 @@ class DTensorExportTest(TestCase):
 "view_9",
 "t_15",
 "detach",
-"detach_1",
-"detach_6",
-"detach_7",
+"detach_3",
 "threshold_backward_1",
 "t_16",
 "mm_6",
@@ -259,10 +257,8 @@ class DTensorExportTest(TestCase):
 "sum_1",
 "view_7",
 "t_7",
-"detach_1",
 "detach_2",
-"detach_3",
-"detach_4",
-"detach_5",
 "threshold_backward",
 "mm_2",
 "t_9",

View File

@@ -921,7 +921,6 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
 1|aten._native_batch_norm_legit_functional.default|batch_norm|
 2|aten.relu.default|relu|
 2|aten.detach.default|relu|
-2|aten.detach.default|relu|
 3|aten.add.Tensor|add|
 4|aten.view.default|flatten|
 5|aten.view.default|linear|
@@ -948,7 +947,6 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
 5|aten.view.default||linear
 4|aten.view.default||flatten
 2|aten.detach.default||relu
-2|aten.detach.default||relu
 2|aten.threshold_backward.default||relu
 1|aten.native_batch_norm_backward.default||batch_norm
 0|aten.convolution_backward.default||conv2d

View File

@@ -241,18 +241,16 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
 ('call_function', 'getitem', {'compile_inductor': 0})
 ('call_function', 'getitem_1', {'compile_inductor': 0})
 ('call_function', 'detach_1', {'compile_inductor': 0})
-('call_function', 'detach_4', {'compile_inductor': 0})
-('call_function', 'detach_5', {'compile_inductor': 0})""", # noqa: B950
+('call_function', 'detach_3', {'compile_inductor': 0})""", # noqa: B950
 )
 self.assertExpectedInline(
 str(bw_metadata),
 """\
 ('placeholder', 'getitem', {'compile_inductor': 0})
-('placeholder', 'detach_5', {'compile_inductor': 0})
+('placeholder', 'detach_3', {'compile_inductor': 0})
 ('call_function', 'zeros', {'compile_inductor': 0})
 ('call_function', 'detach', {'compile_inductor': 0})
 ('call_function', 'detach_2', {'compile_inductor': 0})
-('call_function', 'detach_3', {'compile_inductor': 0})
 ('get_attr', 'fw_graph0', {'compile_inductor': 0})
 []
 ('get_attr', 'joint_graph0', {'compile_inductor': 0})

View File

@@ -684,11 +684,11 @@ class StructuredTraceTest(TestCase):
 {"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
-{"describe_tensor": {"id": 29, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
-{"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"describe_tensor": {"id": 28, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"describe_source": {"describer_id": "ID", "id": 28, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
 {"describe_storage": {"id": 17, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
-{"describe_tensor": {"id": 30, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
-{"describe_source": {"describer_id": "ID", "id": 30, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"describe_tensor": {"id": 29, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
 {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}

View File

@@ -45,11 +45,9 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
 view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
 _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
 alias = torch.ops.aten.alias.default(_softmax)
-alias_1 = torch.ops.aten.alias.default(alias); alias = None
 clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
 _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
-alias_2 = torch.ops.aten.alias.default(_log_softmax)
-alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
+alias_1 = torch.ops.aten.alias.default(_log_softmax)
 mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
 sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
 neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
@@ -59,17 +57,15 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
 neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
 expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
 mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
-alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
-alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None
-exp = torch.ops.aten.exp.default(alias_5); alias_5 = None
+alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None
+exp = torch.ops.aten.exp.default(alias_2); alias_2 = None
 sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
 mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
 sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
-alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None
-alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
-mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None
+alias_3 = torch.ops.aten.alias.default(alias); alias = None
+mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None
 sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
-mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None
+mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None
 sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
 view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
 permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
@@ -91,11 +87,9 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
 view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
 _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
 alias = torch.ops.aten.alias.default(_softmax)
-alias_1 = torch.ops.aten.alias.default(alias); alias = None
 clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
 _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
-alias_2 = torch.ops.aten.alias.default(_log_softmax)
-alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
+alias_1 = torch.ops.aten.alias.default(_log_softmax)
 mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
 sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
 neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
@@ -105,17 +99,15 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
 neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
 expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
 mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
-alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
-alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None
-exp = torch.ops.aten.exp.default(alias_5); alias_5 = None
+alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None
+exp = torch.ops.aten.exp.default(alias_2); alias_2 = None
 sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
 mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
 sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
-alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None
-alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
-mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None
+alias_3 = torch.ops.aten.alias.default(alias); alias = None
+mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None
 sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
-mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None
+mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None
 sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
 view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
 permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])

View File

@@ -1229,9 +1229,7 @@ def forward(self, primals, tangents):
 t = torch.ops.aten.t.default(primals_1); primals_1 = None
 addmm = torch.ops.aten.addmm.default(primals_2, primals_5, t); primals_2 = None
 relu = torch.ops.aten.relu.default(addmm); addmm = None
-detach_9 = torch.ops.aten.detach.default(relu)
-detach_10 = torch.ops.aten.detach.default(detach_9); detach_9 = None
-detach_11 = torch.ops.aten.detach.default(detach_10); detach_10 = None
+detach_3 = torch.ops.aten.detach.default(relu)
 t_1 = torch.ops.aten.t.default(primals_3); primals_3 = None
 addmm_1 = torch.ops.aten.addmm.default(primals_4, relu, t_1); primals_4 = None
 t_2 = torch.ops.aten.t.default(t_1); t_1 = None
@@ -1242,9 +1240,8 @@ def forward(self, primals, tangents):
 sum_1 = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True); tangents_1 = None
 view = torch.ops.aten.view.default(sum_1, [128]); sum_1 = None
 t_5 = torch.ops.aten.t.default(t_4); t_4 = None
-detach_18 = torch.ops.aten.detach.default(detach_11); detach_11 = None
-detach_19 = torch.ops.aten.detach.default(detach_18); detach_18 = None
-threshold_backward = torch.ops.aten.threshold_backward.default(mm, detach_19, 0); mm = detach_19 = None
+detach_6 = torch.ops.aten.detach.default(detach_3); detach_3 = None
+threshold_backward = torch.ops.aten.threshold_backward.default(mm, detach_6, 0); mm = detach_6 = None
 t_6 = torch.ops.aten.t.default(t); t = None
 mm_2 = torch.ops.aten.mm.default(threshold_backward, t_6); t_6 = None
 t_7 = torch.ops.aten.t.default(threshold_backward)
@@ -10302,13 +10299,9 @@ graph():
 %x : [num_users=2] = placeholder[target=x]
 %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
 %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {})
-%detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach,), kwargs = {})
-%detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {})
 %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
-%detach_3 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
-%detach_4 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_3,), kwargs = {})
-%detach_5 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_4,), kwargs = {})
-%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach_2, %detach_5), kwargs = {})
+%detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
+%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_1), kwargs = {})
 %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
 %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
 return (mul_1,)""",

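The exported and AOT-traced graphs above shrink because each traced `.detach()` now records a single `aten.detach.default` node instead of a short chain of them. A quick way to reproduce a trace of this shape (a sketch, not one of the tests touched in this PR):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Tracing runs under a torch-dispatch mode, so every Python dispatch of
# detach becomes a graph node.
gm = make_fx(lambda x: x.detach())(torch.randn(3, requires_grad=True))
gm.graph.print_tabular()  # should now show one aten.detach.default node
```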
View File

@@ -214,9 +214,7 @@ class inner_f(torch.nn.Module):
 where: "f32[2, 3, 4, 4]" = torch.ops.prims.where.default(le, 0.0, add_4); le = add_4 = None
 view_of: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(where)
 view_of_1: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of); view_of = None
-view_of_2: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of_1); view_of_1 = None
-view_of_3: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of_2); view_of_2 = None
-le_1: "b8[2, 3, 4, 4]" = torch.ops.prims.le.default(view_of_3, 0.0); view_of_3 = None
+le_1: "b8[2, 3, 4, 4]" = torch.ops.prims.le.default(view_of_1, 0.0); view_of_1 = None
 where_1: "f32[2, 3, 4, 4]" = torch.ops.prims.where.default(le_1, 0.0, tangents_1); le_1 = tangents_1 = None
 broadcast_in_dim_10: "f32[1, 3]" = torch.ops.prims.broadcast_in_dim.default(squeeze_2, [1, 3], [1]); squeeze_2 = None
 broadcast_in_dim_11: "f32[1, 3, 1]" = torch.ops.prims.broadcast_in_dim.default(broadcast_in_dim_10, [1, 3, 1], [0, 1]); broadcast_in_dim_10 = None

View File

@@ -2278,9 +2278,7 @@ def forward(self, primals_1):
 view = torch.ops.aten.view.default(mul, [-1])
 select = torch.ops.aten.select.int(mul, 0, 0)
 detach = torch.ops.aten.detach.default(select); select = None
-detach_1 = torch.ops.aten.detach.default(detach); detach = None
-detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None
-return (view, mul, detach_2)""",
+return (view, mul, detach)""",
 )
 def test_output_aliases_intermediate_inplace_view(self):
@@ -5138,23 +5136,12 @@ class <lambda>(torch.nn.Module):
 relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None
 detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); detach = None
 detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu)
-detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
-detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_2); detach_2 = None
-detach_4: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_3); detach_3 = None
 sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
-detach_5: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
-detach_6: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_5); detach_5 = None
-detach_7: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_6); detach_6 = None
-detach_8: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_7); detach_7 = None
-detach_9: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_8); detach_8 = None
-detach_10: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_9); detach_9 = None
+detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
 ones_like: "f32[]" = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
 expand: "f32[1, 3, 3, 3]" = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None
-detach_11: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_4); detach_4 = None
-detach_12: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_11); detach_11 = None
-detach_13: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_12); detach_12 = None
-detach_14: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_13); detach_13 = None
-threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_14, 0); expand = detach_14 = None
+detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
+threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_3, 0); expand = detach_3 = None
 native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None
 getitem_5: "f32[1, 3, 3, 3]" = native_batch_norm_backward[0]
 getitem_6: "f32[3]" = native_batch_norm_backward[1]
@@ -5163,7 +5150,7 @@ class <lambda>(torch.nn.Module):
 getitem_8 = convolution_backward[0]; getitem_8 = None
 getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1]
 getitem_10: "f32[3]" = convolution_backward[2]; convolution_backward = None
-return (getitem_3, getitem_4, add, sum_1, detach_10, getitem_9, getitem_10, getitem_6, getitem_7)
+return (getitem_3, getitem_4, add, sum_1, detach_2, getitem_9, getitem_10, getitem_6, getitem_7)
 """, # noqa: B950
 )
@@ -5231,14 +5218,12 @@ class <lambda>(torch.nn.Module):
 relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None
 sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
 detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
-detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach); detach = None
-detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
 return (
 getitem_3, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=4))
 getitem_4, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=5))
 add, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=6))
 sum_1, # PlainAOTOutput(idx=0)
-detach_2, # PlainAOTOutput(idx=1)
+detach, # PlainAOTOutput(idx=1)
 )
 """, # noqa: B950
 )

View File

@@ -1174,12 +1174,10 @@ class TestMemoryProfilerE2E(TestCase):
 aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 20 (AUTOGRAD_DETAIL)
 aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT)
 aten::view 21 (GRADIENT) -> 21 (GRADIENT)
-aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
 aten::detach 21 (GRADIENT) -> ???
 aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
 aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT)
 aten::view 23 (GRADIENT) -> 23 (GRADIENT)
-aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
 aten::detach 23 (GRADIENT) -> ???""",
 )
@@ -1227,12 +1225,10 @@ class TestMemoryProfilerE2E(TestCase):
 aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT)
 aten::view 21 (GRADIENT) -> 21 (GRADIENT)
 aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
-aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
 aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
 aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT)
 aten::view 23 (GRADIENT) -> 23 (GRADIENT)
 aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
-aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
 -- Optimizer --------------------------------------------------------------------------------------------
 aten::add_.Tensor 3 (PARAMETER), 23 (GRADIENT) -> 3 (PARAMETER)
@@ -1277,10 +1273,8 @@ class TestMemoryProfilerE2E(TestCase):
 aten::t 7 (GRADIENT) -> 7 (GRADIENT)
 aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT)
 aten::view 9 (GRADIENT) -> 9 (GRADIENT)
-aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
 aten::detach 9 (GRADIENT) -> ???
 aten::t 7 (GRADIENT) -> 7 (GRADIENT)
-aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
 aten::detach 7 (GRADIENT) -> ???""",
 )
@@ -1318,18 +1312,14 @@ class TestMemoryProfilerE2E(TestCase):
 aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT)
 aten::view 9 (GRADIENT) -> 9 (GRADIENT)
 aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
-aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
 aten::t 7 (GRADIENT) -> 7 (GRADIENT)
 aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
-aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
 -- Optimizer --------------------------------------------------------------------------------------------
 aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
-aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
 aten::clone 7 (GRADIENT) -> 10 (OPTIMIZER_STATE)
 aten::add_.Tensor 2 (PARAMETER), 10 (OPTIMIZER_STATE) -> 2 (PARAMETER)
 aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
-aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
 aten::clone 9 (GRADIENT) -> 11 (OPTIMIZER_STATE)
 aten::add_.Tensor 3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER)""",
 )
@@ -1414,7 +1404,6 @@ class TestMemoryProfilerE2E(TestCase):
 aten::t 7 (PARAMETER) -> 7 (PARAMETER)
 aten::mm 25 (AUTOGRAD_DETAIL), 7 (PARAMETER) -> 27 (AUTOGRAD_DETAIL)
 aten::t 26 (GRADIENT) -> 26 (GRADIENT)
-aten::detach 26 (GRADIENT) -> 26 (GRADIENT)
 aten::detach 26 (GRADIENT) -> ???
 aten::detach 6 (ACTIVATION) -> 6 (ACTIVATION)
 aten::threshold_backward 27 (AUTOGRAD_DETAIL), 6 (ACTIVATION) -> 28 (AUTOGRAD_DETAIL)
@@ -1423,10 +1412,8 @@ class TestMemoryProfilerE2E(TestCase):
 aten::t 29 (GRADIENT) -> 29 (GRADIENT)
 aten::sum.dim_IntList 28 (AUTOGRAD_DETAIL) -> 30 (GRADIENT)
 aten::view 30 (GRADIENT) -> 30 (GRADIENT)
-aten::detach 30 (GRADIENT) -> 30 (GRADIENT)
 aten::detach 30 (GRADIENT) -> ???
 aten::t 29 (GRADIENT) -> 29 (GRADIENT)
-aten::detach 29 (GRADIENT) -> 29 (GRADIENT)
 aten::detach 29 (GRADIENT) -> ???""",
 )

View File

@@ -5050,7 +5050,6 @@ Running aten.expand.default from within SumBackward0
 Running aten.div.Tensor from within DivBackward0
 Running aten.mul.Tensor from within MulBackward0
 Running aten.detach.default from within AccumulateGrad
-Running aten.detach.default from within AccumulateGrad
 Done""",
 )
@@ -7323,9 +7322,7 @@ for shape in [(1,), ()]:
 lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn
 )
 out.backward()
-self.assertEqual(
-    verbose_mode.operators, ["exp.default", "detach.default", "detach.default"]
-)
+self.assertEqual(verbose_mode.operators, ["exp.default", "detach.default"])
 with self.assertRaisesRegex(
 Exception, "only supported when use_reentrant=False"

View File

@@ -850,7 +850,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
 lambda: A(torch.zeros(1)).detach(),
 )
-def test_detach_appears_twice_when_called_once(self) -> None:
+def test_detach_appears_once_when_called_once(self) -> None:
 with capture_logs() as logs:
 x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
 log_input("x", x)
@@ -863,8 +863,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
 "\n".join(logs),
 """\
 $0: f32[1] = input('x')
-$1: f32[1] = torch._ops.aten.detach.default($0)
-$2: f32[1] = torch._ops.aten.detach.default($1)""",
+$1: f32[1] = torch._ops.aten.detach.default($0)""",
 )
def test_storage(self) -> None: def test_storage(self) -> None:

View File

@@ -453,20 +453,18 @@ static Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
     return at::_ops::detach::redispatch(
         ks & c10::after_ADInplaceOrView_keyset, self);
   })();
-  // NB: we can't make detach() a normal view operator because the codegen
-  // generates allow_tensor_metadata_change = True for them. In the future we
-  // should have an option for this in the codegen.
-  auto result = as_view(
-      /* base */ self,
-      /* output */ out,
-      /* is_bw_differentiable */ false,
-      /* is_fw_differentiable */ false,
-      /* view_func */ nullptr,
-      /* rev_view_func */ nullptr,
-      /* creation_meta */ CreationMeta::DEFAULT,
-      /*allow_tensor_metadata_change=*/false);
-  return result;
+  // NB: we can't make detach() a normal view operator because the
+  // codegen generates allow_tensor_metadata_change = True (and leaves
+  // is_fresh_tensor to the default setting of False) for them. In the
+  // future we should have an option for this in the codegen.
+  if (self.is_inference()) {
+    return out;
+  }
+  return ::torch::autograd::make_variable_non_differentiable_view(
+      self,
+      out,
+      /* allow_tensor_metadata_change */ false,
+      /* is_fresh_tensor */ true);
 }
 static Tensor _fw_primal(

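The rewritten kernel above builds the non-differentiable view directly from the freshly allocated output instead of routing it through `as_view`. The user-visible contract of `detach()` is unchanged: the result is a non-requires-grad alias of the same storage, and in-place edits through it are still tracked by the shared version counter. A quick sanity check of those semantics (a sketch, not a test from this PR):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x.detach()

assert not y.requires_grad
assert y.data_ptr() == x.data_ptr()  # still an alias of the same storage

z = x.sin()   # sin saves x for the backward pass
y.zero_()     # mutating the detached alias bumps the shared version counter
try:
    z.sum().backward()
except RuntimeError:
    print("autograd still detects the in-place modification")
```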
View File

@@ -849,11 +849,20 @@ inline Variable make_variable_differentiable_view(
 inline Variable make_variable_non_differentiable_view(
     const Variable& base,
     const at::Tensor& data,
-    bool allow_tensor_metadata_change = true) {
+    bool allow_tensor_metadata_change = true,
+    bool is_fresh_tensor = false) {
   if (data.defined()) {
-    // Currently all of non-differentiable view ops(detach/_indices/_values)
-    // share the same TensorImpl as their base Tensor. Thus a new TensorImpl
-    // allocation here is required.
+    // If we already allocated a new tensor, no need to
+    // shallow_copy_and_detach here. (See #163671 history; we tried to
+    // fan out to _indices and _values and ran into a SparseTensorImpl
+    // can of worms.)
+    if (is_fresh_tensor) {
+      auto* data_impl = data.unsafeGetTensorImpl();
+      data_impl->set_version_counter(impl::version_counter(base));
+      data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+      data_impl->set_autograd_meta(nullptr);
+      return data;
+    }
     auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
         /*version_counter=*/impl::version_counter(base),
         /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
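The new `is_fresh_tensor` fast path keeps the invariants of the copying path: the returned tensor shares the base's version counter, honors `allow_tensor_metadata_change`, and carries no autograd metadata, but it skips `shallow_copy_and_detach`, which is what triggered the second Python dispatch of `detach`. The shared version counter is observable from Python through the private `_version` attribute (a sketch, not part of this PR):

```python
import torch

x = torch.randn(2)
y = x.detach()

before = x._version
y.add_(1)                   # mutate through the detached alias
assert x._version > before  # the version counter is shared with the base
```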