mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix Avoid DDE in item numel check (#164934)
address https://github.com/pytorch/pytorch/issues/164725 and https://github.com/pytorch/pytorch/issues/164704 Pull Request resolved: https://github.com/pytorch/pytorch/pull/164934 Approved by: https://github.com/ezyang, https://github.com/aorenste, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
263db92563
commit
a9a9a3438a
@ -4222,6 +4222,21 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
|
||||
eager_result = func(x, torch.tensor([5]))
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_unbacked_item(self):
|
||||
def func():
|
||||
_x_ms = torch.tensor([True, False], dtype=torch.int64)
|
||||
_mask_ms = torch.zeros_like(_x_ms, dtype=torch.bool)
|
||||
_mask_ms[:1] = True
|
||||
var_node_2 = torch.masked_select(_x_ms, _mask_ms)
|
||||
var_node_0 = var_node_2.item()
|
||||
return var_node_0
|
||||
|
||||
result_original = func()
|
||||
compiled_program = torch.compile(func, fullgraph=True, dynamic=True)
|
||||
result_compiled = compiled_program()
|
||||
self.assertEqual(result_original, result_compiled)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestUnbacked)
|
||||
|
||||
|
Reference in New Issue
Block a user