Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix double dispatch to Python for detach (#163671)
This fixes #71725.

Differential Revision: [D83857880](https://our.internmc.facebook.com/intern/diff/D83857880)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163671
Approved by: https://github.com/ezyang, https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 6bda3bb286
Commit: a3e3efe474
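The behavior change shows up directly in the updated test expectations below: everywhere the traces previously recorded two or more back-to-back `aten.detach.default` calls for a single user-level `detach()`, they now record one. A minimal sketch of how to observe this from Python follows; the `OpCounter` helper is illustrative and not part of this patch, and the expected count assumes a build that includes this change:

```python
# Count how many times each aten op reaches the Python dispatch layer when
# .detach() is called. Before this fix, aten.detach.default was dispatched to
# Python twice per user-level detach(); with the fix it should appear once,
# matching the updated expected outputs in the diff below.
from collections import Counter

import torch
from torch.utils._python_dispatch import TorchDispatchMode


class OpCounter(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.counts = Counter()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Record the op, then redispatch below this mode.
        self.counts[str(func)] += 1
        return func(*args, **(kwargs or {}))


x = torch.randn(3, requires_grad=True)
with OpCounter() as counter:
    x.detach()

# Expected with this patch applied: {'aten.detach.default': 1}
print(dict(counter.counts))
```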
@@ -239,9 +239,7 @@ class DTensorExportTest(TestCase):
 "view_9",
 "t_15",
 "detach",
-"detach_1",
-"detach_6",
-"detach_7",
+"detach_3",
 "threshold_backward_1",
 "t_16",
 "mm_6",
@@ -259,10 +257,8 @@ class DTensorExportTest(TestCase):
 "sum_1",
 "view_7",
 "t_7",
+"detach_1",
 "detach_2",
-"detach_3",
-"detach_4",
-"detach_5",
 "threshold_backward",
 "mm_2",
 "t_9",
@@ -921,7 +921,6 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
 1|aten._native_batch_norm_legit_functional.default|batch_norm|
 2|aten.relu.default|relu|
 2|aten.detach.default|relu|
-2|aten.detach.default|relu|
 3|aten.add.Tensor|add|
 4|aten.view.default|flatten|
 5|aten.view.default|linear|
@@ -948,7 +947,6 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
 5|aten.view.default||linear
 4|aten.view.default||flatten
 2|aten.detach.default||relu
-2|aten.detach.default||relu
 2|aten.threshold_backward.default||relu
 1|aten.native_batch_norm_backward.default||batch_norm
 0|aten.convolution_backward.default||conv2d
@@ -241,18 +241,16 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
 ('call_function', 'getitem', {'compile_inductor': 0})
 ('call_function', 'getitem_1', {'compile_inductor': 0})
 ('call_function', 'detach_1', {'compile_inductor': 0})
-('call_function', 'detach_4', {'compile_inductor': 0})
-('call_function', 'detach_5', {'compile_inductor': 0})""", # noqa: B950
+('call_function', 'detach_3', {'compile_inductor': 0})""", # noqa: B950
 )
 self.assertExpectedInline(
 str(bw_metadata),
 """\
 ('placeholder', 'getitem', {'compile_inductor': 0})
-('placeholder', 'detach_5', {'compile_inductor': 0})
+('placeholder', 'detach_3', {'compile_inductor': 0})
 ('call_function', 'zeros', {'compile_inductor': 0})
 ('call_function', 'detach', {'compile_inductor': 0})
 ('call_function', 'detach_2', {'compile_inductor': 0})
-('call_function', 'detach_3', {'compile_inductor': 0})
 ('get_attr', 'fw_graph0', {'compile_inductor': 0})
 []
 ('get_attr', 'joint_graph0', {'compile_inductor': 0})
@@ -684,11 +684,11 @@ class StructuredTraceTest(TestCase):
 {"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
-{"describe_tensor": {"id": 29, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
-{"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"describe_tensor": {"id": 28, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"describe_source": {"describer_id": "ID", "id": 28, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
 {"describe_storage": {"id": 17, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
-{"describe_tensor": {"id": 30, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
-{"describe_source": {"describer_id": "ID", "id": 30, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"describe_tensor": {"id": 29, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
 {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@@ -45,11 +45,9 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
 view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
 _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
 alias = torch.ops.aten.alias.default(_softmax)
-alias_1 = torch.ops.aten.alias.default(alias); alias = None
 clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
 _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
-alias_2 = torch.ops.aten.alias.default(_log_softmax)
-alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
+alias_1 = torch.ops.aten.alias.default(_log_softmax)
 mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
 sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
 neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
@@ -59,17 +57,15 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
 neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
 expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
 mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
-alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
-alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None
-exp = torch.ops.aten.exp.default(alias_5); alias_5 = None
+alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None
+exp = torch.ops.aten.exp.default(alias_2); alias_2 = None
 sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
 mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
 sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
-alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None
-alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
-mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None
+alias_3 = torch.ops.aten.alias.default(alias); alias = None
+mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None
 sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
-mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None
+mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None
 sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
 view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
 permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
@@ -91,11 +87,9 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
 view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
 _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
 alias = torch.ops.aten.alias.default(_softmax)
-alias_1 = torch.ops.aten.alias.default(alias); alias = None
 clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
 _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
-alias_2 = torch.ops.aten.alias.default(_log_softmax)
-alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
+alias_1 = torch.ops.aten.alias.default(_log_softmax)
 mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
 sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
 neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
@@ -105,17 +99,15 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
 neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
 expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
 mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
-alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
-alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None
-exp = torch.ops.aten.exp.default(alias_5); alias_5 = None
+alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None
+exp = torch.ops.aten.exp.default(alias_2); alias_2 = None
 sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
 mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
 sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
-alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None
-alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
-mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None
+alias_3 = torch.ops.aten.alias.default(alias); alias = None
+mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None
 sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
-mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None
+mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None
 sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
 view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
 permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
@@ -1229,9 +1229,7 @@ def forward(self, primals, tangents):
 t = torch.ops.aten.t.default(primals_1); primals_1 = None
 addmm = torch.ops.aten.addmm.default(primals_2, primals_5, t); primals_2 = None
 relu = torch.ops.aten.relu.default(addmm); addmm = None
-detach_9 = torch.ops.aten.detach.default(relu)
-detach_10 = torch.ops.aten.detach.default(detach_9); detach_9 = None
-detach_11 = torch.ops.aten.detach.default(detach_10); detach_10 = None
+detach_3 = torch.ops.aten.detach.default(relu)
 t_1 = torch.ops.aten.t.default(primals_3); primals_3 = None
 addmm_1 = torch.ops.aten.addmm.default(primals_4, relu, t_1); primals_4 = None
 t_2 = torch.ops.aten.t.default(t_1); t_1 = None
@@ -1242,9 +1240,8 @@ def forward(self, primals, tangents):
 sum_1 = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True); tangents_1 = None
 view = torch.ops.aten.view.default(sum_1, [128]); sum_1 = None
 t_5 = torch.ops.aten.t.default(t_4); t_4 = None
-detach_18 = torch.ops.aten.detach.default(detach_11); detach_11 = None
-detach_19 = torch.ops.aten.detach.default(detach_18); detach_18 = None
-threshold_backward = torch.ops.aten.threshold_backward.default(mm, detach_19, 0); mm = detach_19 = None
+detach_6 = torch.ops.aten.detach.default(detach_3); detach_3 = None
+threshold_backward = torch.ops.aten.threshold_backward.default(mm, detach_6, 0); mm = detach_6 = None
 t_6 = torch.ops.aten.t.default(t); t = None
 mm_2 = torch.ops.aten.mm.default(threshold_backward, t_6); t_6 = None
 t_7 = torch.ops.aten.t.default(threshold_backward)
@@ -10302,13 +10299,9 @@ graph():
 %x : [num_users=2] = placeholder[target=x]
 %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
 %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {})
-%detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach,), kwargs = {})
-%detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {})
 %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
-%detach_3 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
-%detach_4 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_3,), kwargs = {})
-%detach_5 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_4,), kwargs = {})
-%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach_2, %detach_5), kwargs = {})
+%detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
+%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_1), kwargs = {})
 %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
 %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
 return (mul_1,)""",
@@ -214,9 +214,7 @@ class inner_f(torch.nn.Module):
 where: "f32[2, 3, 4, 4]" = torch.ops.prims.where.default(le, 0.0, add_4); le = add_4 = None
 view_of: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(where)
 view_of_1: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of); view_of = None
-view_of_2: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of_1); view_of_1 = None
-view_of_3: "f32[2, 3, 4, 4]" = torch.ops.prims.view_of.default(view_of_2); view_of_2 = None
-le_1: "b8[2, 3, 4, 4]" = torch.ops.prims.le.default(view_of_3, 0.0); view_of_3 = None
+le_1: "b8[2, 3, 4, 4]" = torch.ops.prims.le.default(view_of_1, 0.0); view_of_1 = None
 where_1: "f32[2, 3, 4, 4]" = torch.ops.prims.where.default(le_1, 0.0, tangents_1); le_1 = tangents_1 = None
 broadcast_in_dim_10: "f32[1, 3]" = torch.ops.prims.broadcast_in_dim.default(squeeze_2, [1, 3], [1]); squeeze_2 = None
 broadcast_in_dim_11: "f32[1, 3, 1]" = torch.ops.prims.broadcast_in_dim.default(broadcast_in_dim_10, [1, 3, 1], [0, 1]); broadcast_in_dim_10 = None
@@ -2278,9 +2278,7 @@ def forward(self, primals_1):
 view = torch.ops.aten.view.default(mul, [-1])
 select = torch.ops.aten.select.int(mul, 0, 0)
 detach = torch.ops.aten.detach.default(select); select = None
-detach_1 = torch.ops.aten.detach.default(detach); detach = None
-detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None
-return (view, mul, detach_2)""",
+return (view, mul, detach)""",
 )

 def test_output_aliases_intermediate_inplace_view(self):
@@ -5138,23 +5136,12 @@ class <lambda>(torch.nn.Module):
 relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None
 detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); detach = None
 detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu)
-detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
-detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_2); detach_2 = None
-detach_4: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_3); detach_3 = None
 sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
-detach_5: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
-detach_6: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_5); detach_5 = None
-detach_7: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_6); detach_6 = None
-detach_8: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_7); detach_7 = None
-detach_9: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_8); detach_8 = None
-detach_10: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_9); detach_9 = None
+detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
 ones_like: "f32[]" = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
 expand: "f32[1, 3, 3, 3]" = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None
-detach_11: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_4); detach_4 = None
-detach_12: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_11); detach_11 = None
-detach_13: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_12); detach_12 = None
-detach_14: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_13); detach_13 = None
-threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_14, 0); expand = detach_14 = None
+detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
+threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_3, 0); expand = detach_3 = None
 native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None
 getitem_5: "f32[1, 3, 3, 3]" = native_batch_norm_backward[0]
 getitem_6: "f32[3]" = native_batch_norm_backward[1]
@@ -5163,7 +5150,7 @@ class <lambda>(torch.nn.Module):
 getitem_8 = convolution_backward[0]; getitem_8 = None
 getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1]
 getitem_10: "f32[3]" = convolution_backward[2]; convolution_backward = None
-return (getitem_3, getitem_4, add, sum_1, detach_10, getitem_9, getitem_10, getitem_6, getitem_7)
+return (getitem_3, getitem_4, add, sum_1, detach_2, getitem_9, getitem_10, getitem_6, getitem_7)
 """, # noqa: B950
 )

@@ -5231,14 +5218,12 @@ class <lambda>(torch.nn.Module):
 relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None
 sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
 detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
-detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach); detach = None
-detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
 return (
 getitem_3, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=4))
 getitem_4, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=5))
 add, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=6))
 sum_1, # PlainAOTOutput(idx=0)
-detach_2, # PlainAOTOutput(idx=1)
+detach, # PlainAOTOutput(idx=1)
 )
 """, # noqa: B950
 )
@@ -1174,12 +1174,10 @@ class TestMemoryProfilerE2E(TestCase):
 aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 20 (AUTOGRAD_DETAIL)
 aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT)
 aten::view 21 (GRADIENT) -> 21 (GRADIENT)
-aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
 aten::detach 21 (GRADIENT) -> ???
 aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
 aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT)
 aten::view 23 (GRADIENT) -> 23 (GRADIENT)
-aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
 aten::detach 23 (GRADIENT) -> ???""",
 )

@@ -1227,12 +1225,10 @@ class TestMemoryProfilerE2E(TestCase):
 aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT)
 aten::view 21 (GRADIENT) -> 21 (GRADIENT)
 aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
-aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
 aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
 aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT)
 aten::view 23 (GRADIENT) -> 23 (GRADIENT)
 aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
-aten::detach 23 (GRADIENT) -> 23 (GRADIENT)

 -- Optimizer --------------------------------------------------------------------------------------------
 aten::add_.Tensor 3 (PARAMETER), 23 (GRADIENT) -> 3 (PARAMETER)
@@ -1277,10 +1273,8 @@ class TestMemoryProfilerE2E(TestCase):
 aten::t 7 (GRADIENT) -> 7 (GRADIENT)
 aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT)
 aten::view 9 (GRADIENT) -> 9 (GRADIENT)
-aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
 aten::detach 9 (GRADIENT) -> ???
 aten::t 7 (GRADIENT) -> 7 (GRADIENT)
-aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
 aten::detach 7 (GRADIENT) -> ???""",
 )

@@ -1318,18 +1312,14 @@ class TestMemoryProfilerE2E(TestCase):
 aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT)
 aten::view 9 (GRADIENT) -> 9 (GRADIENT)
 aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
-aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
 aten::t 7 (GRADIENT) -> 7 (GRADIENT)
 aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
-aten::detach 7 (GRADIENT) -> 7 (GRADIENT)

 -- Optimizer --------------------------------------------------------------------------------------------
 aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
-aten::detach 7 (GRADIENT) -> 7 (GRADIENT)
 aten::clone 7 (GRADIENT) -> 10 (OPTIMIZER_STATE)
 aten::add_.Tensor 2 (PARAMETER), 10 (OPTIMIZER_STATE) -> 2 (PARAMETER)
 aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
-aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
 aten::clone 9 (GRADIENT) -> 11 (OPTIMIZER_STATE)
 aten::add_.Tensor 3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER)""",
 )
@@ -1414,7 +1404,6 @@ class TestMemoryProfilerE2E(TestCase):
 aten::t 7 (PARAMETER) -> 7 (PARAMETER)
 aten::mm 25 (AUTOGRAD_DETAIL), 7 (PARAMETER) -> 27 (AUTOGRAD_DETAIL)
 aten::t 26 (GRADIENT) -> 26 (GRADIENT)
-aten::detach 26 (GRADIENT) -> 26 (GRADIENT)
 aten::detach 26 (GRADIENT) -> ???
 aten::detach 6 (ACTIVATION) -> 6 (ACTIVATION)
 aten::threshold_backward 27 (AUTOGRAD_DETAIL), 6 (ACTIVATION) -> 28 (AUTOGRAD_DETAIL)
@@ -1423,10 +1412,8 @@ class TestMemoryProfilerE2E(TestCase):
 aten::t 29 (GRADIENT) -> 29 (GRADIENT)
 aten::sum.dim_IntList 28 (AUTOGRAD_DETAIL) -> 30 (GRADIENT)
 aten::view 30 (GRADIENT) -> 30 (GRADIENT)
-aten::detach 30 (GRADIENT) -> 30 (GRADIENT)
 aten::detach 30 (GRADIENT) -> ???
 aten::t 29 (GRADIENT) -> 29 (GRADIENT)
-aten::detach 29 (GRADIENT) -> 29 (GRADIENT)
 aten::detach 29 (GRADIENT) -> ???""",
 )

@@ -5050,7 +5050,6 @@ Running aten.expand.default from within SumBackward0
 Running aten.div.Tensor from within DivBackward0
 Running aten.mul.Tensor from within MulBackward0
 Running aten.detach.default from within AccumulateGrad
-Running aten.detach.default from within AccumulateGrad
 Done""",
 )

@@ -7323,9 +7322,7 @@ for shape in [(1,), ()]:
 lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn
 )
 out.backward()
-self.assertEqual(
-verbose_mode.operators, ["exp.default", "detach.default", "detach.default"]
-)
+self.assertEqual(verbose_mode.operators, ["exp.default", "detach.default"])

 with self.assertRaisesRegex(
 Exception, "only supported when use_reentrant=False"
@@ -850,7 +850,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
 lambda: A(torch.zeros(1)).detach(),
 )

-def test_detach_appears_twice_when_called_once(self) -> None:
+def test_detach_appears_once_when_called_once(self) -> None:
 with capture_logs() as logs:
 x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
 log_input("x", x)
@@ -863,8 +863,7 @@ $1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
 "\n".join(logs),
 """\
 $0: f32[1] = input('x')
-$1: f32[1] = torch._ops.aten.detach.default($0)
-$2: f32[1] = torch._ops.aten.detach.default($1)""",
+$1: f32[1] = torch._ops.aten.detach.default($0)""",
 )

 def test_storage(self) -> None:
@@ -453,20 +453,18 @@ static Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
 return at::_ops::detach::redispatch(
 ks & c10::after_ADInplaceOrView_keyset, self);
 })();
-// NB: we can't make detach() a normal view operator because the codegen
-// generates allow_tensor_metadata_change = True for them. In the future we
-// should have an option for this in the codegen.
-auto result = as_view(
-/* base */ self,
-/* output */ out,
-/* is_bw_differentiable */ false,
-/* is_fw_differentiable */ false,
-/* view_func */ nullptr,
-/* rev_view_func */ nullptr,
-/* creation_meta */ CreationMeta::DEFAULT,
-/*allow_tensor_metadata_change=*/false);
-
-return result;
+// NB: we can't make detach() a normal view operator because the
+// codegen generates allow_tensor_metadata_change = True (and leaves
+// is_fresh_tensor to the default setting of False) for them. In the
+// future we should have an option for this in the codegen.
+if (self.is_inference()) {
+return out;
+}
+return ::torch::autograd::make_variable_non_differentiable_view(
+self,
+out,
+/* allow_tensor_metadata_change */ false,
+/* is_fresh_tensor */ true);
 }

 static Tensor _fw_primal(
@@ -849,11 +849,20 @@ inline Variable make_variable_differentiable_view(
 inline Variable make_variable_non_differentiable_view(
 const Variable& base,
 const at::Tensor& data,
-bool allow_tensor_metadata_change = true) {
+bool allow_tensor_metadata_change = true,
+bool is_fresh_tensor = false) {
 if (data.defined()) {
-// Currently all of non-differentiable view ops(detach/_indices/_values)
-// share the same TensorImpl as their base Tensor. Thus a new TensorImpl
-// allocation here is required.
+// If we already allocated a new tensor, no need to
+// shallow_copy_and_detach here. (See #163671 history; we tried to
+// fan out to _indices and _values and ran into a SparseTensorImpl
+// can of worms.)
+if (is_fresh_tensor) {
+auto* data_impl = data.unsafeGetTensorImpl();
+data_impl->set_version_counter(impl::version_counter(base));
+data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+data_impl->set_autograd_meta(nullptr);
+return data;
+}
 auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
 /*version_counter=*/impl::version_counter(base),
 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);