mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Implement generator.throw(exception)
(#144424)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144424 Approved by: https://github.com/zou3519 ghstack dependencies: #141055, #144421, #144422, #144423
This commit is contained in:
committed by
PyTorch MergeBot
parent
8ee095f7c1
commit
53ab82d8f5
@ -7,7 +7,7 @@ from collections import OrderedDict
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from torch._dynamo.exc import Unsupported
|
||||
from torch._dynamo.exc import InternalTorchDynamoError, Unsupported
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
@ -977,6 +977,255 @@ class TestGeneratorClose(GeneratorTestsBase):
|
||||
self.assertEqual(z, 2)
|
||||
|
||||
|
||||
class TestGeneratorThrow(GeneratorTestsBase):
|
||||
def test_throw(self):
|
||||
def whoo(t):
|
||||
try:
|
||||
yield t.sin()
|
||||
except RuntimeError:
|
||||
yield t.cos()
|
||||
|
||||
def fn(t):
|
||||
gen = whoo(t)
|
||||
a = next(gen)
|
||||
b = gen.throw(RuntimeError)
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2)
|
||||
y = self._compile_check(fn, (t,))
|
||||
self.assertEqual(y, t.sin() + t.cos())
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
|
||||
def test_throw_with_finally(self):
|
||||
z = 0
|
||||
|
||||
def whoo():
|
||||
nonlocal z
|
||||
z = 0
|
||||
try:
|
||||
try:
|
||||
yield 1
|
||||
except ValueError:
|
||||
yield 2
|
||||
finally:
|
||||
z += 2
|
||||
except ValueError:
|
||||
z += 33
|
||||
yield 4
|
||||
finally:
|
||||
z += 1
|
||||
z += 10
|
||||
|
||||
def f(x):
|
||||
gen = whoo()
|
||||
next(gen)
|
||||
gen.throw(ValueError)
|
||||
return x.sin()
|
||||
|
||||
self._compile_check(f)
|
||||
self.assertEqual(z, 3)
|
||||
|
||||
def test_throw_without_finally(self):
|
||||
z = 0
|
||||
|
||||
def whoo(t):
|
||||
nonlocal z
|
||||
z = 0
|
||||
try:
|
||||
z += 1
|
||||
yield t.sin()
|
||||
z += 10
|
||||
except RuntimeError:
|
||||
z += 100
|
||||
yield t.cos()
|
||||
z += 1_000
|
||||
z += 10_000
|
||||
|
||||
def fn(t):
|
||||
gen = whoo(t)
|
||||
a = next(gen)
|
||||
b = gen.throw(RuntimeError)
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2)
|
||||
y = self._compile_check(fn, (t,))
|
||||
self.assertEqual(y, t.sin() + t.cos())
|
||||
self.assertEqual(z, 101)
|
||||
|
||||
def test_throw_three_arguments(self):
|
||||
def whoo(t):
|
||||
try:
|
||||
yield t.sin()
|
||||
except ValueError:
|
||||
yield t.cos()
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = whoo(t)
|
||||
a = next(gen)
|
||||
b = gen.throw(ValueError, "Error", None)
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2)
|
||||
with self.assertRaises(InternalTorchDynamoError):
|
||||
fn(t)
|
||||
|
||||
def test_throw_no_yield_after_throw(self):
|
||||
z = 0
|
||||
|
||||
def whoo(t):
|
||||
nonlocal z
|
||||
z = 0
|
||||
try:
|
||||
z += 1
|
||||
yield t.sin()
|
||||
except ValueError:
|
||||
z += 10
|
||||
finally:
|
||||
z += 100
|
||||
|
||||
def fn(t):
|
||||
gen = whoo(t)
|
||||
a = next(gen)
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except StopIteration:
|
||||
return a
|
||||
|
||||
t = torch.randn(2)
|
||||
y = self._compile_check(fn, (t,))
|
||||
self.assertEqual(z, 111)
|
||||
self.assertEqual(y, t.sin())
|
||||
|
||||
def test_throw_not_catch(self):
|
||||
z = 0
|
||||
|
||||
def whoo(t):
|
||||
nonlocal z
|
||||
z = 0
|
||||
try:
|
||||
z += 1
|
||||
yield t.sin()
|
||||
except ValueError:
|
||||
z += 10
|
||||
yield t.cos()
|
||||
finally:
|
||||
z += 100
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = whoo(t)
|
||||
a = next(gen)
|
||||
b = gen.throw(RuntimeError)
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2)
|
||||
with self.assertRaises(RuntimeError):
|
||||
fn(t)
|
||||
|
||||
def test_throw_raise_difference_exc(self):
|
||||
z = 0
|
||||
|
||||
def whoo(t):
|
||||
nonlocal z
|
||||
z = 0
|
||||
try:
|
||||
z += 1
|
||||
yield t.sin()
|
||||
except ValueError as e:
|
||||
z += 10
|
||||
raise RuntimeError from e
|
||||
finally:
|
||||
z += 100
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = whoo(t)
|
||||
a = next(gen)
|
||||
b = gen.throw(ValueError)
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2)
|
||||
with self.assertRaises(RuntimeError):
|
||||
fn(t)
|
||||
|
||||
def test_throw_yield_finally(self):
|
||||
z = 0
|
||||
|
||||
def whoo(t):
|
||||
nonlocal z
|
||||
z = 0
|
||||
try:
|
||||
z += 1
|
||||
yield t.sin()
|
||||
except RuntimeError:
|
||||
z += 10
|
||||
yield t.cos()
|
||||
finally:
|
||||
z += 100
|
||||
yield t.tan() # RuntimeError: generator ignored GeneratorExit
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = whoo(t)
|
||||
a = next(gen)
|
||||
b = gen.throw(RuntimeError)
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2)
|
||||
with self.assertRaises(Unsupported):
|
||||
fn(t)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
|
||||
def test_throw_try_except_finally(self):
|
||||
z = 0
|
||||
|
||||
def whoo(t):
|
||||
nonlocal z
|
||||
z = 0
|
||||
try:
|
||||
z += 1
|
||||
yield t.sin()
|
||||
except ValueError:
|
||||
z += 10
|
||||
yield t.cos()
|
||||
except RuntimeError:
|
||||
z += 100
|
||||
yield t.tan()
|
||||
finally:
|
||||
z += 1000
|
||||
z += 10_000
|
||||
|
||||
def fn(t):
|
||||
gen = whoo(t)
|
||||
a = next(gen)
|
||||
b = gen.throw(RuntimeError)
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2)
|
||||
y = self._compile_check(fn, (t,))
|
||||
self.assertEqual(y, t.sin() + t.tan())
|
||||
self.assertEqual(z, 1 + 100 + 1000)
|
||||
|
||||
def test_exception_context_with_yield(self):
|
||||
def f():
|
||||
yield
|
||||
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except ValueError:
|
||||
z = 1
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
assert z == 1
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
|
||||
class GeneratorCloseCPythonTests(GeneratorTestsBase):
|
||||
# Taken from commit
|
||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
||||
@ -1118,6 +1367,149 @@ class GeneratorCloseCPythonTests(GeneratorTestsBase):
|
||||
fn(t)
|
||||
|
||||
|
||||
class GeneratorThrowCpythonTests(GeneratorTestsBase):
|
||||
# Taken from commit
|
||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
||||
# changed the tests a little bit to run them inside dynamo
|
||||
# + replaced all self.assert* calls to plain assert statements
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_exception_context_with_yield(self):
|
||||
def f():
|
||||
try:
|
||||
raise KeyError("a")
|
||||
except Exception:
|
||||
yield
|
||||
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except ValueError as e:
|
||||
context = e.__context__
|
||||
assert (type(context), context.args) == (KeyError, ("a",))
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_exception_context_with_yield_inside_generator(self):
|
||||
# Check that the context is also available from inside the generator
|
||||
# with yield, as opposed to outside.
|
||||
def f():
|
||||
z = 0
|
||||
try:
|
||||
raise KeyError("a")
|
||||
except Exception:
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
z = 1
|
||||
assert type(exc) == ValueError
|
||||
context = exc.__context__
|
||||
assert (type(context), context.args) == (KeyError, ("a",))
|
||||
yield "b"
|
||||
finally:
|
||||
assert z == 1
|
||||
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
actual = gen.throw(ValueError)
|
||||
# This ensures that the assertions inside were executed.
|
||||
assert actual == "b"
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_exception_context_with_yield_from(self):
|
||||
def f():
|
||||
yield
|
||||
|
||||
def g():
|
||||
try:
|
||||
raise KeyError("a")
|
||||
except Exception:
|
||||
yield from f()
|
||||
|
||||
def fn(t):
|
||||
gen = g()
|
||||
gen.send(None)
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except ValueError as e:
|
||||
context = e.__context__
|
||||
assert (type(context), context.args) == (KeyError, ("a",))
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 12), "Test CLEANUP_THROW")
|
||||
@unittest.expectedFailure
|
||||
def test_exception_context_with_yield_from_with_context_cycle(self):
|
||||
# Check trying to create an exception context cycle:
|
||||
# https://bugs.python.org/issue40696
|
||||
has_cycle = None
|
||||
|
||||
def f():
|
||||
yield
|
||||
|
||||
def g(exc):
|
||||
nonlocal has_cycle
|
||||
try:
|
||||
raise exc
|
||||
except Exception:
|
||||
try:
|
||||
yield from f()
|
||||
except Exception as exc:
|
||||
has_cycle = exc is exc.__context__
|
||||
yield
|
||||
|
||||
def fn(t):
|
||||
exc = KeyError("a")
|
||||
gen = g(exc)
|
||||
gen.send(None)
|
||||
gen.throw(exc)
|
||||
# This also distinguishes from the initial has_cycle=None.
|
||||
assert has_cycle is False
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_throw_after_none_exc_type(self):
|
||||
def g():
|
||||
try:
|
||||
raise KeyError
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
yield
|
||||
except Exception:
|
||||
raise RuntimeError # noqa: B904
|
||||
|
||||
def fn(t):
|
||||
gen = g()
|
||||
gen.send(None)
|
||||
z = 0
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except RuntimeError:
|
||||
z += 1
|
||||
except Exception:
|
||||
raise AssertionError # noqa: B904
|
||||
assert z == 1
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
|
||||
class GeneratorCPythonTests(GeneratorTestsBase):
|
||||
# Taken from commit
|
||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
||||
|
@ -309,6 +309,14 @@ observed_exception_map = {
|
||||
}
|
||||
|
||||
|
||||
def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]:
|
||||
if exc_type not in observed_exception_map:
|
||||
observed_exception_map[exc_type] = type(
|
||||
f"Observed{exc_type.__name__}Error", (ObservedException,), {}
|
||||
)
|
||||
return observed_exception_map[exc_type]
|
||||
|
||||
|
||||
def raise_observed_exception(
|
||||
exc_type: type[Exception],
|
||||
tx: InstructionTranslatorBase,
|
||||
|
@ -16,7 +16,9 @@ import torch
|
||||
from .. import polyfills, variables
|
||||
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
|
||||
from ..exc import (
|
||||
get_dynamo_observed_exception,
|
||||
handle_observed_exception,
|
||||
IncorrectUsage,
|
||||
InfiniteGeneratorError,
|
||||
ObservedException,
|
||||
ObservedGeneratorExit,
|
||||
@ -604,6 +606,104 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
||||
# https://github.com/python/cpython/pull/104771
|
||||
assert tracer.symbolic_result is not None
|
||||
return tracer.symbolic_result
|
||||
elif name == "throw":
|
||||
# * Raises an exception at the point where the generator was paused, and
|
||||
# returns the next value yielded by the generator.
|
||||
# * If the generator exits without yielding, raise StopIteration
|
||||
# * If the generator function does not catch the passed-in exception,
|
||||
# or raises a different exception, then that exception propagates to the caller.
|
||||
|
||||
if len(args) > 1:
|
||||
raise IncorrectUsage(
|
||||
"the (type, exc, tb) signature of throw() is deprecated, "
|
||||
"use the single-arg signature instead."
|
||||
)
|
||||
|
||||
# Setup the exception table and jump target in case of try...finally
|
||||
tracer = self._get_inline_tracer(tx)
|
||||
try:
|
||||
self._setup_exception(tx, args[0])
|
||||
except ObservedException:
|
||||
# propagate the exception back to the parent caller
|
||||
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
|
||||
raise
|
||||
|
||||
retval = self.next_variable(tx)
|
||||
|
||||
# The exception raised before is still active. We need to check the exception
|
||||
# table one more time to find the next target. But why? Let’s walk
|
||||
# through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M
|
||||
#
|
||||
# z = 0
|
||||
# def whoo():
|
||||
# global z
|
||||
# z = 0
|
||||
# try:
|
||||
# yield 1
|
||||
# except ValueError:
|
||||
# yield 2
|
||||
# finally:
|
||||
# z += 1
|
||||
# z += 10
|
||||
#
|
||||
# gen = whoo()
|
||||
# next(gen)
|
||||
# gen.throw(ValueError)
|
||||
# print('z', z) -> z = 1
|
||||
#
|
||||
# ...
|
||||
# >> 58 PUSH_EXC_INFO
|
||||
#
|
||||
# 8 60 LOAD_GLOBAL 2 (ValueError)
|
||||
# 70 CHECK_EXC_MATCH
|
||||
# 72 POP_JUMP_IF_FALSE 7 (to 88)
|
||||
# 74 POP_TOP
|
||||
#
|
||||
# 9 76 LOAD_CONST 3 (2)
|
||||
# 78 YIELD_VALUE 3 <------ ValueError is still active here
|
||||
# 80 RESUME 1
|
||||
# 82 POP_TOP
|
||||
# 84 POP_EXCEPT
|
||||
# 86 jump_backward 34 (to 20)
|
||||
# ...
|
||||
#
|
||||
# ExceptionTable:
|
||||
# 4 to 8 -> 124 [0] lasti
|
||||
# 12 to 18 -> 58 [0]
|
||||
# 20 to 56 -> 124 [0] lasti
|
||||
# 58 to 82 -> 90 [1] lasti <------ move to 90
|
||||
# 84 to 86 -> 96 [0]
|
||||
# 88 to 88 -> 90 [1] lasti
|
||||
# 90 to 94 -> 96 [0]
|
||||
# 96 to 116 -> 118 [1] lasti
|
||||
# 118 to 122 -> 124 [0] lasti
|
||||
#
|
||||
# In this scenario, a generator can yield after `throw()` is called. Even
|
||||
# after the exception is raised a few lines above, it remains active
|
||||
# within the `78 YIELD_VALUE` instruction. When the generator resumes
|
||||
# after the second yield on instruction `80 RESUME`, we cannot simply
|
||||
# return the control flow to the next instruction. Instead, one must
|
||||
# check the exception table (or equivalent) to find the next target
|
||||
# In this case, it says the instruction pointer must be moved to 90.
|
||||
#
|
||||
# Without this step, if we let the trace proceed to the next
|
||||
# instruction, it would follow the control flow where the exception
|
||||
# raised by `throw()` was handled and swallowed, potentially leading
|
||||
# to incorrect behavior.
|
||||
exc_type = type("__InternalThrowException", (Exception,), {})
|
||||
|
||||
try:
|
||||
self._setup_exception(tx, variables.ExceptionVariable(exc_type, ()))
|
||||
self.next_variable(tx)
|
||||
except get_dynamo_observed_exception(exc_type):
|
||||
# We should get back the exception raised before.
|
||||
pass
|
||||
except ObservedException:
|
||||
# Propagate anything else back to the parent caller
|
||||
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
|
||||
else:
|
||||
raise_observed_exception(RuntimeError, tracer)
|
||||
return retval
|
||||
|
||||
super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
Reference in New Issue
Block a user