[dynamo, guards] Better error messages when generated guard fails on the same frame (#165242)

Not sure what exactly we want to have in the message, but that's easy to adjust. I tried to find a reliable test to reproduce this message (happens only when a guard fails right after it's created), but I ended up mocking a `guard_manager.check` function to return `False` to trigger this behavior. I think that's fine, because any other case that we pick (like datetime.now()), we want to patch one day anyway, so every time we make the next patch, will need to chase for another repro test

@williamwen42

Fixes #164990

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165242
Approved by: https://github.com/williamwen42
This commit is contained in:
jmaczan
2025-10-16 01:05:28 +00:00
committed by PyTorch MergeBot
parent c2bd41ac9f
commit 003dd13073
2 changed files with 39 additions and 1 deletions

View File

@ -7369,6 +7369,41 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
)
self.assertEqual(explain_output.break_reasons[0].reason, expected_msg)
# https://github.com/pytorch/pytorch/issues/164990
def test_guard_same_frame_fail_message(self):
import torch._dynamo.guards as g
# deterministically fail check on the same frame to verify error message correctness
# the other example of fail might be datetime.now() until patched - see issue #164990
compile_check_fn = g.CheckFunctionManager.compile_check_fn
def wrapper(self, builder, sorted_guards, guard_fail_fn):
compile_check_fn(self, builder, sorted_guards, guard_fail_fn)
def check(x):
return False
self.guard_manager.check = check
with mock.patch.object(g.CheckFunctionManager, "compile_check_fn", new=wrapper):
class Model(nn.Module):
def forward(self, x):
return x + 1
model = Model()
x = torch.randn(5)
with self.assertRaises(AssertionError) as e:
torch.compile(model)(x)
msg = str(e.exception)
self.assertIn(
"Guard failed on the same frame it was created. This is a bug - please create an issue."
"Guard fail reason: ",
msg,
)
class ReproTestsDevice(torch._dynamo.test_case.TestCase):
def test_sub_alpha_scalar_repro(self, device):

View File

@ -3605,7 +3605,10 @@ class CheckFunctionManager:
output_graph.local_scope,
CompileContext.current_compile_id(),
)
raise AssertionError(f"Guard check failed: {reasons}")
raise AssertionError(
"Guard failed on the same frame it was created. This is a bug - please create an issue."
f"Guard fail reason: {reasons}"
)
if guard_manager_testing_hook_fn is not None:
guard_manager_testing_hook_fn(