Implement generator.throw(exception) (#144424)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144424
Approved by: https://github.com/zou3519
ghstack dependencies: #141055, #144421, #144422, #144423
This commit is contained in:
Guilherme Leobas
2025-02-07 14:55:21 -03:00
committed by PyTorch MergeBot
parent 8ee095f7c1
commit 53ab82d8f5
3 changed files with 501 additions and 1 deletions

View File

@ -7,7 +7,7 @@ from collections import OrderedDict
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.exc import Unsupported
from torch._dynamo.exc import InternalTorchDynamoError, Unsupported
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@ -977,6 +977,255 @@ class TestGeneratorClose(GeneratorTestsBase):
self.assertEqual(z, 2)
class TestGeneratorThrow(GeneratorTestsBase):
def test_throw(self):
def whoo(t):
try:
yield t.sin()
except RuntimeError:
yield t.cos()
def fn(t):
gen = whoo(t)
a = next(gen)
b = gen.throw(RuntimeError)
return a + b
t = torch.randn(2)
y = self._compile_check(fn, (t,))
self.assertEqual(y, t.sin() + t.cos())
@unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
def test_throw_with_finally(self):
z = 0
def whoo():
nonlocal z
z = 0
try:
try:
yield 1
except ValueError:
yield 2
finally:
z += 2
except ValueError:
z += 33
yield 4
finally:
z += 1
z += 10
def f(x):
gen = whoo()
next(gen)
gen.throw(ValueError)
return x.sin()
self._compile_check(f)
self.assertEqual(z, 3)
def test_throw_without_finally(self):
z = 0
def whoo(t):
nonlocal z
z = 0
try:
z += 1
yield t.sin()
z += 10
except RuntimeError:
z += 100
yield t.cos()
z += 1_000
z += 10_000
def fn(t):
gen = whoo(t)
a = next(gen)
b = gen.throw(RuntimeError)
return a + b
t = torch.randn(2)
y = self._compile_check(fn, (t,))
self.assertEqual(y, t.sin() + t.cos())
self.assertEqual(z, 101)
def test_throw_three_arguments(self):
def whoo(t):
try:
yield t.sin()
except ValueError:
yield t.cos()
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = whoo(t)
a = next(gen)
b = gen.throw(ValueError, "Error", None)
return a + b
t = torch.randn(2)
with self.assertRaises(InternalTorchDynamoError):
fn(t)
def test_throw_no_yield_after_throw(self):
z = 0
def whoo(t):
nonlocal z
z = 0
try:
z += 1
yield t.sin()
except ValueError:
z += 10
finally:
z += 100
def fn(t):
gen = whoo(t)
a = next(gen)
try:
gen.throw(ValueError)
except StopIteration:
return a
t = torch.randn(2)
y = self._compile_check(fn, (t,))
self.assertEqual(z, 111)
self.assertEqual(y, t.sin())
def test_throw_not_catch(self):
z = 0
def whoo(t):
nonlocal z
z = 0
try:
z += 1
yield t.sin()
except ValueError:
z += 10
yield t.cos()
finally:
z += 100
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = whoo(t)
a = next(gen)
b = gen.throw(RuntimeError)
return a + b
t = torch.randn(2)
with self.assertRaises(RuntimeError):
fn(t)
def test_throw_raise_difference_exc(self):
z = 0
def whoo(t):
nonlocal z
z = 0
try:
z += 1
yield t.sin()
except ValueError as e:
z += 10
raise RuntimeError from e
finally:
z += 100
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = whoo(t)
a = next(gen)
b = gen.throw(ValueError)
return a + b
t = torch.randn(2)
with self.assertRaises(RuntimeError):
fn(t)
def test_throw_yield_finally(self):
z = 0
def whoo(t):
nonlocal z
z = 0
try:
z += 1
yield t.sin()
except RuntimeError:
z += 10
yield t.cos()
finally:
z += 100
yield t.tan() # RuntimeError: generator ignored GeneratorExit
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = whoo(t)
a = next(gen)
b = gen.throw(RuntimeError)
return a + b
t = torch.randn(2)
with self.assertRaises(Unsupported):
fn(t)
@unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
def test_throw_try_except_finally(self):
z = 0
def whoo(t):
nonlocal z
z = 0
try:
z += 1
yield t.sin()
except ValueError:
z += 10
yield t.cos()
except RuntimeError:
z += 100
yield t.tan()
finally:
z += 1000
z += 10_000
def fn(t):
gen = whoo(t)
a = next(gen)
b = gen.throw(RuntimeError)
return a + b
t = torch.randn(2)
y = self._compile_check(fn, (t,))
self.assertEqual(y, t.sin() + t.tan())
self.assertEqual(z, 1 + 100 + 1000)
def test_exception_context_with_yield(self):
def f():
yield
def fn(t):
gen = f()
gen.send(None)
try:
gen.throw(ValueError)
except ValueError:
z = 1
except Exception as e:
raise AssertionError from e
assert z == 1
return t.sin()
self._compile_check(fn)
class GeneratorCloseCPythonTests(GeneratorTestsBase):
# Taken from commit
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
@ -1118,6 +1367,149 @@ class GeneratorCloseCPythonTests(GeneratorTestsBase):
fn(t)
class GeneratorThrowCpythonTests(GeneratorTestsBase):
# Taken from commit
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
# changed the tests a little bit to run them inside dynamo
# + replaced all self.assert* calls to plain assert statements
@unittest.expectedFailure
def test_exception_context_with_yield(self):
def f():
try:
raise KeyError("a")
except Exception:
yield
def fn(t):
gen = f()
gen.send(None)
try:
gen.throw(ValueError)
except ValueError as e:
context = e.__context__
assert (type(context), context.args) == (KeyError, ("a",))
except Exception as e:
raise AssertionError from e
return t.sin()
self._compile_check(fn)
@unittest.expectedFailure
def test_exception_context_with_yield_inside_generator(self):
# Check that the context is also available from inside the generator
# with yield, as opposed to outside.
def f():
z = 0
try:
raise KeyError("a")
except Exception:
try:
yield
except Exception as exc:
z = 1
assert type(exc) == ValueError
context = exc.__context__
assert (type(context), context.args) == (KeyError, ("a",))
yield "b"
finally:
assert z == 1
def fn(t):
gen = f()
gen.send(None)
actual = gen.throw(ValueError)
# This ensures that the assertions inside were executed.
assert actual == "b"
return t.sin()
self._compile_check(fn)
@unittest.expectedFailure
def test_exception_context_with_yield_from(self):
def f():
yield
def g():
try:
raise KeyError("a")
except Exception:
yield from f()
def fn(t):
gen = g()
gen.send(None)
try:
gen.throw(ValueError)
except ValueError as e:
context = e.__context__
assert (type(context), context.args) == (KeyError, ("a",))
except Exception as e:
raise AssertionError from e
return t.sin()
self._compile_check(fn)
@unittest.skipIf(sys.version_info < (3, 12), "Test CLEANUP_THROW")
@unittest.expectedFailure
def test_exception_context_with_yield_from_with_context_cycle(self):
# Check trying to create an exception context cycle:
# https://bugs.python.org/issue40696
has_cycle = None
def f():
yield
def g(exc):
nonlocal has_cycle
try:
raise exc
except Exception:
try:
yield from f()
except Exception as exc:
has_cycle = exc is exc.__context__
yield
def fn(t):
exc = KeyError("a")
gen = g(exc)
gen.send(None)
gen.throw(exc)
# This also distinguishes from the initial has_cycle=None.
assert has_cycle is False
return t.sin()
self._compile_check(fn)
def test_throw_after_none_exc_type(self):
def g():
try:
raise KeyError
except KeyError:
pass
try:
yield
except Exception:
raise RuntimeError # noqa: B904
def fn(t):
gen = g()
gen.send(None)
z = 0
try:
gen.throw(ValueError)
except RuntimeError:
z += 1
except Exception:
raise AssertionError # noqa: B904
assert z == 1
return t.sin()
self._compile_check(fn)
class GeneratorCPythonTests(GeneratorTestsBase):
# Taken from commit
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10

View File

@ -309,6 +309,14 @@ observed_exception_map = {
}
def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]:
if exc_type not in observed_exception_map:
observed_exception_map[exc_type] = type(
f"Observed{exc_type.__name__}Error", (ObservedException,), {}
)
return observed_exception_map[exc_type]
def raise_observed_exception(
exc_type: type[Exception],
tx: InstructionTranslatorBase,

View File

@ -16,7 +16,9 @@ import torch
from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
get_dynamo_observed_exception,
handle_observed_exception,
IncorrectUsage,
InfiniteGeneratorError,
ObservedException,
ObservedGeneratorExit,
@ -604,6 +606,104 @@ class LocalGeneratorObjectVariable(VariableTracker):
# https://github.com/python/cpython/pull/104771
assert tracer.symbolic_result is not None
return tracer.symbolic_result
elif name == "throw":
# * Raises an exception at the point where the generator was paused, and
# returns the next value yielded by the generator.
# * If the generator exits without yielding, raise StopIteration
# * If the generator function does not catch the passed-in exception,
# or raises a different exception, then that exception propagates to the caller.
if len(args) > 1:
raise IncorrectUsage(
"the (type, exc, tb) signature of throw() is deprecated, "
"use the single-arg signature instead."
)
# Setup the exception table and jump target in case of try...finally
tracer = self._get_inline_tracer(tx)
try:
self._setup_exception(tx, args[0])
except ObservedException:
# propagate the exception back to the parent caller
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
raise
retval = self.next_variable(tx)
# The exception raised before is still active. We need to check the exception
# table one more time to find the next target. But why? Lets walk
# through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M
#
# z = 0
# def whoo():
# global z
# z = 0
# try:
# yield 1
# except ValueError:
# yield 2
# finally:
# z += 1
# z += 10
#
# gen = whoo()
# next(gen)
# gen.throw(ValueError)
# print('z', z) -> z = 1
#
# ...
# >> 58 PUSH_EXC_INFO
#
# 8 60 LOAD_GLOBAL 2 (ValueError)
# 70 CHECK_EXC_MATCH
# 72 POP_JUMP_IF_FALSE 7 (to 88)
# 74 POP_TOP
#
# 9 76 LOAD_CONST 3 (2)
# 78 YIELD_VALUE 3 <------ ValueError is still active here
# 80 RESUME 1
# 82 POP_TOP
# 84 POP_EXCEPT
# 86 jump_backward 34 (to 20)
# ...
#
# ExceptionTable:
# 4 to 8 -> 124 [0] lasti
# 12 to 18 -> 58 [0]
# 20 to 56 -> 124 [0] lasti
# 58 to 82 -> 90 [1] lasti <------ move to 90
# 84 to 86 -> 96 [0]
# 88 to 88 -> 90 [1] lasti
# 90 to 94 -> 96 [0]
# 96 to 116 -> 118 [1] lasti
# 118 to 122 -> 124 [0] lasti
#
# In this scenario, a generator can yield after `throw()` is called. Even
# after the exception is raised a few lines above, it remains active
# within the `78 YIELD_VALUE` instruction. When the generator resumes
# after the second yield on instruction `80 RESUME`, we cannot simply
# return the control flow to the next instruction. Instead, one must
# check the exception table (or equivalent) to find the next target
# In this case, it says the instruction pointer must be moved to 90.
#
# Without this step, if we let the trace proceed to the next
# instruction, it would follow the control flow where the exception
# raised by `throw()` was handled and swallowed, potentially leading
# to incorrect behavior.
exc_type = type("__InternalThrowException", (Exception,), {})
try:
self._setup_exception(tx, variables.ExceptionVariable(exc_type, ()))
self.next_variable(tx)
except get_dynamo_observed_exception(exc_type):
# We should get back the exception raised before.
pass
except ObservedException:
# Propagate anything else back to the parent caller
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
else:
raise_observed_exception(RuntimeError, tracer)
return retval
super().call_method(tx, name, args, kwargs)