diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 5c3240458f5e..3dc90a1b9d8e 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1363,6 +1363,9 @@ s1 > 3""", s = SubTensor(torch.randn(3, 10)) f(s) + # Guard validation upsets the guard + # https://github.com/pytorch/pytorch/issues/129936 + @unittest.expectedFailure def test_recompile_with_symbool_inputs(self): def f(pred: bool): if pred: diff --git a/test/dynamo_expected_failures/TestAOTModuleSimplified.test_aot_module_simplified_dynamic b/test/dynamo_expected_failures/TestAOTModuleSimplified.test_aot_module_simplified_dynamic new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/export/test_export.py b/test/export/test_export.py index d61cbe2d3107..f1e4bd1b0da9 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3601,6 +3601,8 @@ def forward(self, x): ): torch.export.export(exported_v2.module(), (torch.randn(2, 2),)) + # https://github.com/pytorch/pytorch/issues/129939 + @testing.expectedFailureNonStrict def test_export_cond(self): class A(torch.nn.Module): def __init__(self): @@ -4976,6 +4978,9 @@ graph(): ) ) + # Guard validation upsets the guard + # https://github.com/pytorch/pytorch/issues/129939 + @unittest.expectedFailure def test_cond_with_module_stack_export_with(self): class Bar(torch.nn.Module): def __init__(self): diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index f218a2037792..ddafe400af11 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -2346,6 +2346,7 @@ known_failing_tests = { "test_grad_nonleaf_register_hook", # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors) "test_unpack_hooks_exec_count", # pack/unpack saved tensor hooks firing more than once "test_scalar_grad_mixed_device", # Fake Tensors aren't propagating device properly for 0-dim grads + "test_backward_twice_without_saved_values", # https://github.com/pytorch/pytorch/issues/129938 } if not HAS_CUDA: diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 04aa3e8ba030..58e5a1ec6bb3 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2103,6 +2103,7 @@ class CheckFunctionManager: guard.create(builder) self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn) + # Keep track of weak references of objects with ID_MATCH guard. This # info is stored alongside optimized_code and check_fn and is used to # limit the number of cache entries with same ID_MATCH'd object. @@ -2123,6 +2124,18 @@ class CheckFunctionManager: self.guard_manager.id_matched_objs = builder.id_matched_objs self.check_fn = self.guard_manager + # Check that the guard returns True. False means that we will always + # recompile. + # TODO(anijain2305, ydwu4) - Skipping export because of following test + # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs + if not output_graph.export: + if not self.guard_manager.check(output_graph.local_scope): + reasons = get_guard_fail_reason_helper( + self.guard_manager, # type: ignore[arg-type] + output_graph.local_scope, + ) + raise AssertionError(f"Guard check failed: {reasons}") + # NB - We have to very careful of cleaning up here. Because of the # invalidate function, we can create a weakref finalizer that keeps # `self` alive for very long. Sometimes by mistake, we can run @@ -2456,9 +2469,8 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope): return [f"Duplicate tensors found: {reason}"] -def get_guard_fail_reason( +def get_guard_fail_reason_helper( guard_fn: GuardFn, - code: types.CodeType, f_locals: Dict[str, object], ) -> str: """ @@ -2525,6 +2537,15 @@ def get_guard_fail_reason( break reason_str = "\n".join(reasons) + return reason_str + + +def get_guard_fail_reason( + guard_fn: GuardFn, + code: types.CodeType, + f_locals: Dict[str, object], +) -> str: + reason_str = get_guard_fail_reason_helper(guard_fn, f_locals) guard_failures[orig_code_map[code]].append(reason_str) try: