Fix flaky Dynamo export tests (#96488)

Planning to do a full writeup later. The short story is, sometimes the following chain of events happens:

1. We turn on Dynamo's custom frame handler
2. GC triggers (and all of the finalizers run under Dynamo)
3. GC hits a GeneratorExit frame
4. You end up in the custom frame handler with throw_flag == TRUE and PyErr_Occurred() != NULL

If this happens and we blindly call into other Python functions (like the Python callback), the executed Python code will immediately raise an exception, because there is already an ambient exception set. This is very, very confusing. The fix is to defer to the regular handler when throw_flag is TRUE.

I triggered this locally with

```
PYTHONUNBUFFERED=1 pytest test/dynamo/test_dynamic_shapes.py   -k 'Unspec and export and not dupes and not reorder' -v -x -s
```

But I also have some tests which trigger the problem synthetically.

Fixes https://github.com/pytorch/pytorch/issues/93781

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96488
Approved by: https://github.com/albanD
This commit is contained in:
Edward Z. Yang
2023-03-10 08:09:57 -08:00
committed by PyTorch MergeBot
parent 7fcf8b1829
commit 80ce1a934e
3 changed files with 129 additions and 0 deletions

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
import collections
import contextlib
import copy
import inspect
import itertools
@ -2239,6 +2240,81 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 2)
def test_exception_in_dynamo_handling(self):
    """Check that an exception raised under a Dynamo-optimized frame still
    reaches the ``except`` clause of an enclosing context manager.

    See https://github.com/pytorch/pytorch/pull/96488
    """
    handler_ran = False

    @contextlib.contextmanager
    def catch_runtime_error():
        # Flip the flag only when a RuntimeError is thrown into the generator
        # at the ``yield`` point.
        nonlocal handler_ran
        try:
            yield
        except RuntimeError:
            handler_ran = True

    def raise_boof():
        raise RuntimeError("boof")

    @torch._dynamo.optimize("eager")
    def run_guarded():
        with catch_runtime_error():
            raise_boof()

    # The context manager swallows the RuntimeError, so this call itself
    # should not error.
    run_guarded()
    self.assertTrue(handler_ran)
def test_generator_dealloc(self):
    """Drop the last reference to a generator while Dynamo is active.

    A generator expression is stored in a list; a Dynamo-skipped function
    ``f`` clears that list and then calls ``g``.  The test asserts that
    running the optimized ``f`` reports no unraisable exception and that
    the compile counter still records exactly one op (the ``x + 2`` in
    ``g``).
    """
    # See https://github.com/pytorch/pytorch/pull/96488
    #
    # NB: yes, [(...)] is intentional, this is a list containing a
    # generator
    generator_box = [(x for x in [1, 2, 3])]

    counter = torch._dynamo.testing.CompileCounter()

    def g(x):
        return x + 2

    # TODO: This test is pretty delicate. To test if it's actually doing
    # anything, rebuild eval_frame.c with '#define TORCHDYNAMO_DEBUG 1'
    # and then look at the logs for:
    #
    # TRACE[_custom_eval_frame:650] begin <genexpr> test_repros.py 2276 -1 0 0
    # TRACE[_custom_eval_frame:664] throw <genexpr>
    #
    # This means we're actually hitting the relevant codepath

    # NB: Make sure we don't actually Dynamo this frame; if we do Dynamo
    # this frame, Dynamo actually DOES understand list.clear and will
    # arrange for the generator deallocation to happen when the eval frame
    # handler is disabled, which will prevent the bug from happening (we
    # specifically want to trigger the generator deallocation WHILE the
    # dynamo eval frame handler is active), as that will cause the
    # generator to become exhausted and trigger the throw_flag == TRUE
    # case.
    @torch._dynamo.skip
    def f(x):
        # Clearing the list removes the reference it held to the generator.
        generator_box.clear()
        return g(x)

    self.assertNoUnraisable(
        lambda: torch._dynamo.optimize(counter)(f)(torch.randn(3))
    )

    # Make sure the x + 2 is captured (a previous incorrect implementation
    # of this fix would have disabled the eval frame callback, which means
    # g wouldn't get traced).
    self.assertEqual(counter.op_count, 1)
def test_error_return_without_exception_set(self):
    """Run a ``torch.compile``-d function that builds a generator expression
    only to take its type, and assert no unraisable exception is reported.
    """
    # https://github.com/pytorch/pytorch/issues/93781
    @torch.compile
    def f():
        # The generator expression is passed straight to type() and is
        # never iterated or bound to a name.
        _generator_type = type((_ for _ in ()))

    self.assertNoUnraisable(f)
@skip_if_pytest
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_with_msg(self):

View File

@ -646,6 +646,32 @@ static PyObject* _custom_eval_frame(
frame->f_lasti,
frame->f_iblock,
frame->f_executing);
if (throw_flag) {
// When unwinding generators, eval frame is called with throw_flag ==
// true. Frame evaluation is supposed to continue unwinding by propagating
// the exception. Dynamo doesn't really know how to do this, nor does it
// really want to do this, because there's unlikely any code to capture
// (you're going to immediately quit out of the frame, perhaps running
// some unwinding logic along the way). So we just run the default
// handler in this case.
//
// NB: A previous version of this patch returned NULL. This is wrong,
// because returning NULL is *different* from unwinding an exception.
// In particular, you will not execute things like context manager
// __exit__ if you just return NULL.
//
// NB: It's /conceivable/ that you might want to actually still call the
// Dynamo callback when throw_flag == TRUE, to give Dynamo a chance to
// do any stack unwinding code. But this is not really useful because
// (1) Dynamo doesn't actually know how to do stack unwinding, so it would
// immediately skip the frame, and (2) even if it did, this would only
// be profitable if there was tensor code in the unwinding code. Seems
// unlikely.
DEBUG_TRACE("throw %s", name(frame));
return eval_frame_default(tstate, frame, throw_flag);
}
CacheEntry* extra = get_extra(frame->f_code);
if (extra == SKIP_CODE || (callback == Py_False && extra == NULL)) {
DEBUG_TRACE("skip %s", name(frame));
@ -715,6 +741,10 @@ static PyObject* _custom_eval_frame(
// internal exception, returning here will leak the exception into user code
// this is useful for debugging -- but we dont want it to happen outside of
// testing
// NB: we intentionally DO NOT re-enable custom behavior to prevent
// cascading failure from internal exceptions. The upshot is if
// Dynamo barfs, that's it for Dynamo, even if you catch the exception
// inside the torch.compile block we won't try to Dynamo anything else.
return NULL;
} else if (result != Py_None) {
DEBUG_TRACE("create cache %s", name(frame));

View File

@ -3070,6 +3070,29 @@ class TestCase(expecttest.TestCase):
else:
return super().assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
# Verifies that no unraisable exceptions are raised by callable. Unlike regular
# exceptions, these do not actually propagate to the caller and are
# suppressed. We must test for them specially.
def assertNoUnraisable(self, callable, *args, **kwargs):
    """Fail the test if running ``callable`` reports an unraisable exception.

    Unraisable exceptions (for example, ones raised inside ``__del__``) are
    delivered to ``sys.unraisablehook`` instead of propagating to the
    caller, so they have to be intercepted there.

    Args:
        callable: The callable to invoke.
        *args: Positional arguments forwarded to ``callable``.
        **kwargs: Keyword arguments forwarded to ``callable``.

    Raises:
        self.failureException: If at least one unraisable exception was
            reported while ``callable`` ran.  The message lists every
            report that was seen.
    """
    reports = []

    def record_unraisable(unraisable):
        # Keep only a textual summary.  Storing the hook argument itself
        # would retain ``exc_value`` (possible reference cycle) and
        # ``object`` (which can resurrect an object that is being
        # finalized) after the hook returns.
        try:
            detail = str(unraisable.exc_value)
        except Exception:
            # Never let the hook itself raise, or the report would be lost.
            detail = "<unprintable exception value>"
        exc_name = getattr(
            unraisable.exc_type, "__name__", str(unraisable.exc_type)
        )
        prefix = f"{unraisable.err_msg}: " if unraisable.err_msg else ""
        reports.append(f"{prefix}{exc_name}: {detail}")

    # Disable GC when running the callable to prevent spurious flakiness
    # from unlucky GCs inside the callable
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        with unittest.mock.patch("sys.unraisablehook", record_unraisable):
            callable(*args, **kwargs)
    finally:
        # Only re-enable GC if it was enabled when we were called.
        if gc_was_enabled:
            gc.enable()

    if reports:
        self.fail(
            f"{len(reports)} unraisable exception(s) were reported: "
            + "; ".join(reports)
        )
# TODO: Support context manager interface
# NB: The kwargs forwarding to callable robs the 'subname' parameter.
# If you need it, manually apply your callable in a lambda instead.