mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix flaky Dynamo export tests (#96488)
Planning to do a full writeup later. The short story is, sometimes the following chain of events happens: 1. We turn on Dynamo's custom frame handler 2. GC triggers (and all of the finalizers run under Dynamo) 3. GC hits a GeneratorExit frame 4. You end up in the custom frame handler with throw_flag == TRUE and PyErr_Occurred() != NULL If this happens and we blindly call into other Python functions (like the Python callback), the executed Python code will immediately raise an exception (because there's already an ambient exception set.) This is very, very confusing. The fix is to defer to the regular handler when throw_flag is TRUE. I triggered this locally with ``` PYTHONUNBUFFERED=1 pytest test/dynamo/test_dynamic_shapes.py -k 'Unspec and export and not dupes and not reorder' -v -x -s ``` But I also have some tests which trigger the problem synthetically. Fixes https://github.com/pytorch/pytorch/issues/93781 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/96488 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
7fcf8b1829
commit
80ce1a934e
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import collections
|
||||
import contextlib
|
||||
import copy
|
||||
import inspect
|
||||
import itertools
|
||||
@ -2239,6 +2240,81 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
self.assertEqual(cnt.op_count, 2)
|
||||
|
||||
def test_exception_in_dynamo_handling(self):
|
||||
hit_handler = False
|
||||
|
||||
# See https://github.com/pytorch/pytorch/pull/96488
|
||||
@contextlib.contextmanager
|
||||
def ctx():
|
||||
try:
|
||||
yield
|
||||
except RuntimeError:
|
||||
nonlocal hit_handler
|
||||
hit_handler = True
|
||||
|
||||
@torch._dynamo.optimize("eager")
|
||||
def f():
|
||||
with ctx():
|
||||
h()
|
||||
|
||||
def h():
|
||||
raise RuntimeError("boof")
|
||||
|
||||
# Should not error
|
||||
f()
|
||||
self.assertTrue(hit_handler)
|
||||
|
||||
def test_generator_dealloc(self):
|
||||
# See https://github.com/pytorch/pytorch/pull/96488
|
||||
#
|
||||
# NB: yes, [(...)] is intentional, this is a list containing a
|
||||
# generator
|
||||
generator_box = [(x for x in [1, 2, 3])]
|
||||
|
||||
counter = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
def g(x):
|
||||
return x + 2
|
||||
|
||||
# TODO: This test is pretty delicate. To test if it's actually doing
|
||||
# anything, rebuild eval_frame.c with '#define TORCHDYNAMO_DEBUG 1'
|
||||
# and then look at the logs for:
|
||||
#
|
||||
# TRACE[_custom_eval_frame:650] begin <genexpr> test_repros.py 2276 -1 0 0
|
||||
# TRACE[_custom_eval_frame:664] throw <genexpr>
|
||||
#
|
||||
# This means we're actually hitting the relevant codepath
|
||||
|
||||
# NB: Make sure we don't actually Dynamo this frame; if we do Dynamo
|
||||
# this frame, Dynamo actually DOES understand list.clear and will
|
||||
# arrange for the generator deallocation to happen when the eval frame
|
||||
# handler is disabled, which will prevent the bug from happening (we
|
||||
# specifically want to trigger the generator deallocation WHILE the
|
||||
# dynamo eval frame handler is active), as that will cause the
|
||||
# generator to become exhausted and trigger the throw_flag == TRUE
|
||||
# case.
|
||||
@torch._dynamo.skip
|
||||
def f(x):
|
||||
generator_box.clear()
|
||||
return g(x)
|
||||
|
||||
self.assertNoUnraisable(
|
||||
lambda: torch._dynamo.optimize(counter)(f)(torch.randn(3))
|
||||
)
|
||||
|
||||
# Make sure the x + 2 is captured (a previous incorrect implementation
|
||||
# of this fix would have disabled the eval frame callback, which means
|
||||
# g wouldn't get traced
|
||||
self.assertEqual(counter.op_count, 1)
|
||||
|
||||
def test_error_return_without_exception_set(self):
|
||||
# https://github.com/pytorch/pytorch/issues/93781
|
||||
@torch.compile
|
||||
def f():
|
||||
_generator_type = type((_ for _ in ()))
|
||||
|
||||
self.assertNoUnraisable(f)
|
||||
|
||||
@skip_if_pytest
|
||||
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
|
||||
def test_rewrite_assert_with_msg(self):
|
||||
|
@ -646,6 +646,32 @@ static PyObject* _custom_eval_frame(
|
||||
frame->f_lasti,
|
||||
frame->f_iblock,
|
||||
frame->f_executing);
|
||||
|
||||
if (throw_flag) {
|
||||
// When unwinding generators, eval frame is called with throw_flag ==
|
||||
// true. Frame evaluation is supposed to continue unwinding by propagating
|
||||
// the exception. Dynamo doesn't really know how to do this, nor does it
|
||||
// really want to do this, because there's unlikely any code to capture
|
||||
// (you're going to immediately quit out of the frame, perhaps running
|
||||
// some unwinding logic along the way). So we just run the default
|
||||
// handler in this case.
|
||||
//
|
||||
// NB: A previous version of this patch returned NULL. This is wrong,
|
||||
// because returning NULL is *different* from unwinding an exception.
|
||||
// In particular, you will not execute things like context manager
|
||||
// __exit__ if you just return NULL.
|
||||
//
|
||||
// NB: It's /conceivable/ that you might want to actually still call the
|
||||
// Dynamo callback when throw_flag == TRUE, to give Dynamo a chance to
|
||||
// do any stack unwinding code. But this is not really useful because
|
||||
// (1) Dynamo doesn't actually know how to do stack unwinding, so it would
|
||||
// immediately skip the frame, and (2) even if it did, this would only
|
||||
// be profitable if there was tensor code in the unwinding code. Seems
|
||||
// unlikely.
|
||||
DEBUG_TRACE("throw %s", name(frame));
|
||||
return eval_frame_default(tstate, frame, throw_flag);
|
||||
}
|
||||
|
||||
CacheEntry* extra = get_extra(frame->f_code);
|
||||
if (extra == SKIP_CODE || (callback == Py_False && extra == NULL)) {
|
||||
DEBUG_TRACE("skip %s", name(frame));
|
||||
@ -715,6 +741,10 @@ static PyObject* _custom_eval_frame(
|
||||
// internal exception, returning here will leak the exception into user code
|
||||
// this is useful for debugging -- but we dont want it to happen outside of
|
||||
// testing
|
||||
// NB: we intentionally DO NOT re-enable custom behavior to prevent
|
||||
// cascading failure from internal exceptions. The upshot is if
|
||||
// Dynamo barfs, that's it for Dynamo, even if you catch the exception
|
||||
// inside the torch.compile block we won't try to Dynamo anything else.
|
||||
return NULL;
|
||||
} else if (result != Py_None) {
|
||||
DEBUG_TRACE("create cache %s", name(frame));
|
||||
|
@ -3070,6 +3070,29 @@ class TestCase(expecttest.TestCase):
|
||||
else:
|
||||
return super().assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
|
||||
|
||||
# Verifies that no unraisable exceptions are raised by callable. Unlike regular
|
||||
# exceptions, these do not actually propagate to the caller and are
|
||||
# suppressed. We must test for them specially.
|
||||
def assertNoUnraisable(self, callable, *args, **kwargs):
|
||||
raised = None
|
||||
|
||||
def record_unraisable(unraisable):
|
||||
nonlocal raised
|
||||
raised = unraisable
|
||||
|
||||
# Disable GC when running the callable to prevent spurious flakiness
|
||||
# from unlucky GCs inside the callable
|
||||
prev = gc.isenabled()
|
||||
gc.disable()
|
||||
try:
|
||||
with unittest.mock.patch("sys.unraisablehook", record_unraisable):
|
||||
callable(*args, **kwargs)
|
||||
finally:
|
||||
if prev:
|
||||
gc.enable()
|
||||
|
||||
self.assertIsNone(raised)
|
||||
|
||||
# TODO: Support context manager interface
|
||||
# NB: The kwargs forwarding to callable robs the 'subname' parameter.
|
||||
# If you need it, manually apply your callable in a lambda instead.
|
||||
|
Reference in New Issue
Block a user