Laith Sakka
2025-10-08 09:27:12 -07:00
committed by PyTorch MergeBot
parent 6a7f5c0d21
commit 17c7170ca6
2 changed files with 20 additions and 1 deletions

View File

@ -4222,6 +4222,21 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
eager_result = func(x, torch.tensor([5]))
self.assertEqual(cnt.frame_count, 2)
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_item(self):
def func():
_x_ms = torch.tensor([True, False], dtype=torch.int64)
_mask_ms = torch.zeros_like(_x_ms, dtype=torch.bool)
_mask_ms[:1] = True
var_node_2 = torch.masked_select(_x_ms, _mask_ms)
var_node_0 = var_node_2.item()
return var_node_0
result_original = func()
compiled_program = torch.compile(func, fullgraph=True, dynamic=True)
result_compiled = compiled_program()
self.assertEqual(result_original, result_compiled)
instantiate_parametrized_tests(TestUnbacked)