Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Fix double dispatch to Python for detach (#163671)
This fixes #71725.

Differential Revision: [D83857880](https://our.internmc.facebook.com/intern/diff/D83857880)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163671
Approved by: https://github.com/ezyang, https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 815d641599
Commit: 331b7cc054
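The user-visible effect of this change is that a single `.detach()` call now reaches Python dispatch once instead of twice, which is what the updated test expectations below encode. A minimal sketch of how to observe this with a `TorchDispatchMode` (the `OpLogger` class is illustrative and not part of this diff):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpLogger(TorchDispatchMode):
    """Record every ATen op that reaches Python dispatch."""
    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))
        return func(*args, **(kwargs or {}))

x = torch.randn(3, requires_grad=True)
with OpLogger() as logger:
    x.detach()

# Before this fix the log showed aten.detach.default twice for one call;
# with the fix it should appear only once.
print(logger.ops)
```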
@@ -239,9 +239,7 @@ class DTensorExportTest(TestCase):
"view_9",
"t_15",
"detach",
"detach_1",
"detach_6",
"detach_7",
"detach_3",
"threshold_backward_1",
"t_16",
"mm_6",
@@ -259,10 +257,8 @@ class DTensorExportTest(TestCase):
"sum_1",
"view_7",
"t_7",
"detach_1",
"detach_2",
"detach_3",
"detach_4",
"detach_5",
"threshold_backward",
"mm_2",
"t_9",
@@ -921,7 +921,6 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
1|aten._native_batch_norm_legit_functional.default|batch_norm|
2|aten.relu.default|relu|
2|aten.detach.default|relu|
2|aten.detach.default|relu|
3|aten.add.Tensor|add|
4|aten.view.default|flatten|
5|aten.view.default|linear|
@@ -948,7 +947,6 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
5|aten.view.default||linear
4|aten.view.default||flatten
2|aten.detach.default||relu
2|aten.detach.default||relu
2|aten.threshold_backward.default||relu
1|aten.native_batch_norm_backward.default||batch_norm
0|aten.convolution_backward.default||conv2d
@@ -216,18 +216,16 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
('call_function', 'getitem', {'compile_inductor': 0})
('call_function', 'getitem_1', {'compile_inductor': 0})
('call_function', 'detach_1', {'compile_inductor': 0})
('call_function', 'detach_4', {'compile_inductor': 0})
('call_function', 'detach_5', {'compile_inductor': 0})""", # noqa: B950
('call_function', 'detach_3', {'compile_inductor': 0})""", # noqa: B950
)
self.assertExpectedInline(
str(bw_metadata),
"""\
('placeholder', 'getitem', {'compile_inductor': 0})
('placeholder', 'detach_5', {'compile_inductor': 0})
('placeholder', 'detach_3', {'compile_inductor': 0})
('call_function', 'zeros', {'compile_inductor': 0})
('call_function', 'detach', {'compile_inductor': 0})
('call_function', 'detach_2', {'compile_inductor': 0})
('call_function', 'detach_3', {'compile_inductor': 0})
('get_attr', 'fw_graph0', {'compile_inductor': 0})
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('get_attr', 'joint_graph0', {'compile_inductor': 0})
@@ -684,11 +684,11 @@ class StructuredTraceTest(TestCase):
{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 29, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 28, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 28, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_storage": {"id": 17, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 30, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 30, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 29, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@@ -45,11 +45,9 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
_softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
alias = torch.ops.aten.alias.default(_softmax)
alias_1 = torch.ops.aten.alias.default(alias); alias = None
clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
_log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
alias_2 = torch.ops.aten.alias.default(_log_softmax)
alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
alias_1 = torch.ops.aten.alias.default(_log_softmax)
mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
@@ -59,17 +57,15 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None
exp = torch.ops.aten.exp.default(alias_5); alias_5 = None
alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None
exp = torch.ops.aten.exp.default(alias_2); alias_2 = None
sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None
alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None
alias_3 = torch.ops.aten.alias.default(alias); alias = None
mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None
sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None
mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None
sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
@@ -91,11 +87,9 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
_softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
alias = torch.ops.aten.alias.default(_softmax)
alias_1 = torch.ops.aten.alias.default(alias); alias = None
clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
_log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
alias_2 = torch.ops.aten.alias.default(_log_softmax)
alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
alias_1 = torch.ops.aten.alias.default(_log_softmax)
mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
@@ -105,17 +99,15 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None
exp = torch.ops.aten.exp.default(alias_5); alias_5 = None
alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None
exp = torch.ops.aten.exp.default(alias_2); alias_2 = None
sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None
alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None
alias_3 = torch.ops.aten.alias.default(alias); alias = None
mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None
sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None
mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None
sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
@@ -1229,9 +1229,7 @@ def forward(self, primals, tangents):
t = torch.ops.aten.t.default(primals_1); primals_1 = None
addmm = torch.ops.aten.addmm.default(primals_2, primals_5, t); primals_2 = None
relu = torch.ops.aten.relu.default(addmm); addmm = None
detach_9 = torch.ops.aten.detach.default(relu)
detach_10 = torch.ops.aten.detach.default(detach_9); detach_9 = None
detach_11 = torch.ops.aten.detach.default(detach_10); detach_10 = None
detach_3 = torch.ops.aten.detach.default(relu)
t_1 = torch.ops.aten.t.default(primals_3); primals_3 = None
addmm_1 = torch.ops.aten.addmm.default(primals_4, relu, t_1); primals_4 = None
t_2 = torch.ops.aten.t.default(t_1); t_1 = None
@@ -1242,9 +1240,8 @@ def forward(self, primals, tangents):
sum_1 = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True); tangents_1 = None
view = torch.ops.aten.view.default(sum_1, [128]); sum_1 = None
t_5 = torch.ops.aten.t.default(t_4); t_4 = None
detach_18 = torch.ops.aten.detach.default(detach_11); detach_11 = None
detach_19 = torch.ops.aten.detach.default(detach_18); detach_18 = None
threshold_backward = torch.ops.aten.threshold_backward.default(mm, detach_19, 0); mm = detach_19 = None
detach_6 = torch.ops.aten.detach.default(detach_3); detach_3 = None
threshold_backward = torch.ops.aten.threshold_backward.default(mm, detach_6, 0); mm = detach_6 = None
t_6 = torch.ops.aten.t.default(t); t = None
mm_2 = torch.ops.aten.mm.default(threshold_backward, t_6); t_6 = None
t_7 = torch.ops.aten.t.default(threshold_backward)
@@ -10320,13 +10317,9 @@ graph():
%x : [num_users=2] = placeholder[target=x]
%ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
%detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {})
%detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach,), kwargs = {})
%detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {})
%clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
%detach_3 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
%detach_4 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_3,), kwargs = {})
%detach_5 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_4,), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach_2, %detach_5), kwargs = {})
%detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_1), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
return (mul_1,)""",
@@ -237,9 +237,7 @@ class inner_f(torch.nn.Module):
where: "f32[2, 3, 4, 4]" = torch.ops.prims.where.default(le, 0.0, add_4); le = add_4 = None
view_of: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(where)
view_of_1: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of); view_of = None
view_of_2: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of_1); view_of_1 = None
view_of_3: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of_2); view_of_2 = None
le_1: "b8[2, 3, 4, 4]" = torch.ops.prims.le.default(view_of_3, 0.0); view_of_3 = None
le_1: "b8[2, 3, 4, 4]" = torch.ops.prims.le.default(view_of_1, 0.0); view_of_1 = None
where_1: "f32[2, 3, 4, 4]" = torch.ops.prims.where.default(le_1, 0.0, tangents_1); le_1 = tangents_1 = None
broadcast_in_dim_10: "f32[1, 3]" = torch.ops.prims.broadcast_in_dim.default(squeeze_2, [1, 3], [1]); squeeze_2 = None
broadcast_in_dim_11: "f32[1, 3, 1]" = torch.ops.prims.broadcast_in_dim.default(broadcast_in_dim_10, [1, 3, 1], [0, 1]); broadcast_in_dim_10 = None
@@ -2278,9 +2278,7 @@ def forward(self, primals_1):
view = torch.ops.aten.view.default(mul, [-1])
select = torch.ops.aten.select.int(mul, 0, 0)
detach = torch.ops.aten.detach.default(select); select = None
detach_1 = torch.ops.aten.detach.default(detach); detach = None
detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None
return (view, mul, detach_2)""",
return (view, mul, detach)""",
)

def test_output_aliases_intermediate_inplace_view(self):
@@ -5138,23 +5136,12 @@ class <lambda>(torch.nn.Module):
relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None
detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); detach = None
detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu)
detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_2); detach_2 = None
detach_4: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_3); detach_3 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
detach_5: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
detach_6: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_5); detach_5 = None
detach_7: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_6); detach_6 = None
detach_8: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_7); detach_7 = None
detach_9: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_8); detach_8 = None
detach_10: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_9); detach_9 = None
detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
ones_like: "f32[]" = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
expand: "f32[1, 3, 3, 3]" = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None
detach_11: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_4); detach_4 = None
detach_12: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_11); detach_11 = None
detach_13: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_12); detach_12 = None
detach_14: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_13); detach_13 = None
threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_14, 0); expand = detach_14 = None
detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_3, 0); expand = detach_3 = None
native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None
getitem_5: "f32[1, 3, 3, 3]" = native_batch_norm_backward[0]
getitem_6: "f32[3]" = native_batch_norm_backward[1]
@@ -5163,7 +5150,7 @@ class <lambda>(torch.nn.Module):
getitem_8 = convolution_backward[0]; getitem_8 = None
getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1]
getitem_10: "f32[3]" = convolution_backward[2]; convolution_backward = None
return (getitem_3, getitem_4, add, sum_1, detach_10, getitem_9, getitem_10, getitem_6, getitem_7)
return (getitem_3, getitem_4, add, sum_1, detach_2, getitem_9, getitem_10, getitem_6, getitem_7)
""", # noqa: B950
)
@@ -5231,14 +5218,12 @@ class <lambda>(torch.nn.Module):
relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None
sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach); detach = None
detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
return (
getitem_3, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=4))
getitem_4, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=5))
add, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=6))
sum_1, # PlainAOTOutput(idx=0)
detach_2, # PlainAOTOutput(idx=1)
detach, # PlainAOTOutput(idx=1)
)
""", # noqa: B950
)
@@ -1174,12 +1174,10 @@ class TestMemoryProfilerE2E(TestCase):
aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 20 (AUTOGRAD_DETAIL)
aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT)
aten::view 21 (GRADIENT) -> 21 (GRADIENT)
aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
aten::detach 21 (GRADIENT) -> ???
aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT)
aten::view 23 (GRADIENT) -> 23 (GRADIENT)
aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
aten::detach 23 (GRADIENT) -> ???""",
)
@@ -1227,12 +1225,10 @@ class TestMemoryProfilerE2E(TestCase):
aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT)
aten::view 21 (GRADIENT) -> 21 (GRADIENT)
aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT)
aten::view 23 (GRADIENT) -> 23 (GRADIENT)
aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
aten::detach 23 (GRADIENT) -> 23 (GRADIENT)

-- Optimizer --------------------------------------------------------------------------------------------
aten::add_.Tensor 3 (PARAMETER), 23 (GRADIENT) -> 3 (PARAMETER)
@@ -1277,10 +1273,8 @@ class TestMemoryProfilerE2E(TestCase):
aten::t 7 (GRADIENT) -> 7 (GRADIENT)
aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT)
aten::view 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> ???
aten::t 7 (GRADIENT) -> 7 (GRADIENT)
aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
aten::detach 7 (GRADIENT) -> ???""",
)
@@ -1318,18 +1312,14 @@ class TestMemoryProfilerE2E(TestCase):
aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT)
aten::view 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::t 7 (GRADIENT) -> 7 (GRADIENT)
aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
aten::detach 7 (GRADIENT) -> 7 (GRADIENT)

-- Optimizer --------------------------------------------------------------------------------------------
aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
aten::clone 7 (GRADIENT) -> 10 (OPTIMIZER_STATE)
aten::add_.Tensor 2 (PARAMETER), 10 (OPTIMIZER_STATE) -> 2 (PARAMETER)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::clone 9 (GRADIENT) -> 11 (OPTIMIZER_STATE)
aten::add_.Tensor 3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER)""",
)
@@ -1414,7 +1404,6 @@ class TestMemoryProfilerE2E(TestCase):
aten::t 7 (PARAMETER) -> 7 (PARAMETER)
aten::mm 25 (AUTOGRAD_DETAIL), 7 (PARAMETER) -> 27 (AUTOGRAD_DETAIL)
aten::t 26 (GRADIENT) -> 26 (GRADIENT)
aten::detach 26 (GRADIENT) -> 26 (GRADIENT)
aten::detach 26 (GRADIENT) -> ???
aten::detach 6 (ACTIVATION) -> 6 (ACTIVATION)
aten::threshold_backward 27 (AUTOGRAD_DETAIL), 6 (ACTIVATION) -> 28 (AUTOGRAD_DETAIL)
@@ -1423,10 +1412,8 @@ class TestMemoryProfilerE2E(TestCase):
aten::t 29 (GRADIENT) -> 29 (GRADIENT)
aten::sum.dim_IntList 28 (AUTOGRAD_DETAIL) -> 30 (GRADIENT)
aten::view 30 (GRADIENT) -> 30 (GRADIENT)
aten::detach 30 (GRADIENT) -> 30 (GRADIENT)
aten::detach 30 (GRADIENT) -> ???
aten::t 29 (GRADIENT) -> 29 (GRADIENT)
aten::detach 29 (GRADIENT) -> 29 (GRADIENT)
aten::detach 29 (GRADIENT) -> ???""",
)
@@ -5050,7 +5050,6 @@ Running aten.expand.default from within SumBackward0
Running aten.div.Tensor from within DivBackward0
Running aten.mul.Tensor from within MulBackward0
Running aten.detach.default from within AccumulateGrad
Running aten.detach.default from within AccumulateGrad
Done""",
)
@@ -7323,9 +7322,7 @@ for shape in [(1,), ()]:
lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn
)
out.backward()
self.assertEqual(
verbose_mode.operators, ["exp.default", "detach.default", "detach.default"]
)
self.assertEqual(verbose_mode.operators, ["exp.default", "detach.default"])

with self.assertRaisesRegex(
Exception, "only supported when use_reentrant=False"
@@ -850,7 +850,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
lambda: A(torch.zeros(1)).detach(),
)

def test_detach_appears_twice_when_called_once(self) -> None:
def test_detach_appears_once_when_called_once(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
log_input("x", x)
@@ -863,8 +863,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.detach.default($0)
$2: f32[1] = torch._ops.aten.detach.default($1)""",
$1: f32[1] = torch._ops.aten.detach.default($0)""",
)

def test_storage(self) -> None:
@@ -453,20 +453,18 @@ static Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
return at::_ops::detach::redispatch(
ks & c10::after_ADInplaceOrView_keyset, self);
})();
// NB: we can't make detach() a normal view operator because the codegen
// generates allow_tensor_metadata_change = True for them. In the future we
// should have an option for this in the codegen.
auto result = as_view(
/* base */ self,
/* output */ out,
/* is_bw_differentiable */ false,
/* is_fw_differentiable */ false,
/* view_func */ nullptr,
/* rev_view_func */ nullptr,
/* creation_meta */ CreationMeta::DEFAULT,
/*allow_tensor_metadata_change=*/false);

return result;
// NB: we can't make detach() a normal view operator because the
// codegen generates allow_tensor_metadata_change = True (and leaves
// is_fresh_tensor to the default setting of False) for them. In the
// future we should have an option for this in the codegen.
if (self.is_inference()) {
return out;
}
return ::torch::autograd::make_variable_non_differentiable_view(
self,
out,
/* allow_tensor_metadata_change */ false,
/* is_fresh_tensor */ true);
}

static Tensor _fw_primal(
@@ -858,11 +858,20 @@ inline Variable make_variable_differentiable_view(
inline Variable make_variable_non_differentiable_view(
const Variable& base,
const at::Tensor& data,
bool allow_tensor_metadata_change = true) {
bool allow_tensor_metadata_change = true,
bool is_fresh_tensor = false) {
if (data.defined()) {
// Currently all of non-differentiable view ops(detach/_indices/_values)
// share the same TensorImpl as their base Tensor. Thus a new TensorImpl
// allocation here is required.
// If we already allocated a new tensor, no need to
// shallow_copy_and_detach here. (See #163671 history; we tried to
// fan out to _indices and _values and ran into a SparseTensorImpl
// can of worms.)
if (is_fresh_tensor) {
auto* data_impl = data.unsafeGetTensorImpl();
data_impl->set_version_counter(impl::version_counter(base));
data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
data_impl->set_autograd_meta(nullptr);
return data;
}
auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
/*version_counter=*/impl::version_counter(base),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
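The `is_fresh_tensor` fast path reuses the TensorImpl that the redispatched `detach` already allocated instead of calling `shallow_copy_and_detach`, while still propagating the base's version counter and the `allow_tensor_metadata_change` flag. Passing `false` for that flag is what makes metadata mutations on a detached tensor fail; a hedged illustration (the exact error message may vary between releases):

```python
import torch

x = torch.randn(3)
y = x.detach()

try:
    # Resizing mutates the tensor's metadata; a tensor produced by .detach()
    # has allow_tensor_metadata_change=false, so this is expected to raise.
    y.resize_(6)
except RuntimeError as err:
    print("metadata change rejected:", err)
```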