mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Correctly propagate exception to parent tx (#146502)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146502 Approved by: https://github.com/anijain2305, https://github.com/williamwen42, https://github.com/zou3519 ghstack dependencies: #146504, #146499
This commit is contained in:
committed by
PyTorch MergeBot
parent
fb53e9e514
commit
daff65d671
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
@ -11,7 +12,11 @@ import torch.nn
|
||||
import torch.utils.checkpoint
|
||||
from torch._dynamo.bytecode_transformation import Instruction
|
||||
from torch._dynamo.symbolic_convert import SpeculationLog, SpeculationLogDivergence
|
||||
from torch.testing._internal.common_utils import make_dynamo_test
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
make_dynamo_test,
|
||||
parametrize,
|
||||
)
|
||||
|
||||
|
||||
class CustomException(Exception):
|
||||
@ -123,6 +128,33 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_propagate_exception_inside_ctx_manager(self):
|
||||
@contextlib.contextmanager
|
||||
def cm():
|
||||
try:
|
||||
yield
|
||||
except BaseException:
|
||||
raise ValueError # noqa: B904
|
||||
|
||||
@contextlib.contextmanager
|
||||
def nothing():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
|
||||
z = 0
|
||||
with nothing():
|
||||
try:
|
||||
with cm():
|
||||
raise IndexError
|
||||
except ValueError:
|
||||
z = 1
|
||||
except IndexError:
|
||||
z = 2
|
||||
assert z == 1
|
||||
|
||||
def test_exception_else(self):
|
||||
def gn(x):
|
||||
return torch.cos(x)
|
||||
@ -145,6 +177,64 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@make_dynamo_test
|
||||
def test_raise_match(self):
|
||||
a = AttributeError
|
||||
b = BytesWarning
|
||||
c = ConnectionError
|
||||
d = DeprecationWarning
|
||||
e = Exception
|
||||
|
||||
def fn(a, b):
|
||||
try:
|
||||
raise a
|
||||
finally:
|
||||
raise b
|
||||
|
||||
def fix_exc_context(frame_exc, new_exc, old_exc):
|
||||
# slightly change from ExitStack.fix_exc_context function
|
||||
while 1:
|
||||
exc_context = new_exc.__context__
|
||||
if exc_context is None or exc_context is old_exc:
|
||||
return
|
||||
if exc_context is frame_exc:
|
||||
break
|
||||
new_exc = exc_context
|
||||
new_exc.__context__ = old_exc
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ctx():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
frame_exc = prev_exc = sys.exc_info()
|
||||
args = [(d, c), (b, a)]
|
||||
for x, y in args:
|
||||
try:
|
||||
fn(x, y)
|
||||
except BaseException:
|
||||
new_exc = sys.exc_info()
|
||||
fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1])
|
||||
prev_exc = new_exc
|
||||
|
||||
try:
|
||||
fixed_ctx = prev_exc[1].__context__
|
||||
raise prev_exc[1]
|
||||
except BaseException:
|
||||
prev_exc[1].__context__ = fixed_ctx
|
||||
raise
|
||||
|
||||
try:
|
||||
with ctx():
|
||||
raise e
|
||||
except Exception as exc:
|
||||
assert isinstance(exc, a)
|
||||
assert isinstance(exc.__context__, b)
|
||||
assert isinstance(exc.__context__.__context__, c)
|
||||
assert isinstance(exc.__context__.__context__.__context__, d)
|
||||
assert isinstance(exc.__context__.__context__.__context__.__context__, e)
|
||||
|
||||
# TODO(anijain2305) - does not work with fullgraph=True
|
||||
def test_exception_with_another_exception2(self):
|
||||
def gn(x):
|
||||
@ -455,6 +545,103 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(ref[0], res[0])
|
||||
self.assertEqual(ref[1], res[1])
|
||||
|
||||
@make_dynamo_test
|
||||
def test_reraise_first_exc(self):
|
||||
def fn():
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except ZeroDivisionError:
|
||||
try:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
pass
|
||||
raise
|
||||
|
||||
try:
|
||||
fn()
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
assert sys.exc_info()[0] is None
|
||||
|
||||
@make_dynamo_test
|
||||
def test_ensure_exception_is_active_after_try_except_block(self):
|
||||
try:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except ZeroDivisionError:
|
||||
for exc in (KeyError, IndexError):
|
||||
try:
|
||||
raise exc
|
||||
except exc:
|
||||
pass
|
||||
raise
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
assert sys.exc_info()[0] is None
|
||||
|
||||
@make_dynamo_test
|
||||
def test_ensure_exception_is_active_inside_try_except_block(self):
|
||||
try:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except ZeroDivisionError:
|
||||
for exc in (KeyError, IndexError):
|
||||
try:
|
||||
raise exc
|
||||
except exc as e:
|
||||
assert isinstance(e.__context__, ZeroDivisionError)
|
||||
raise
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
assert sys.exc_info()[0] is None
|
||||
|
||||
@make_dynamo_test
|
||||
def test_handle_all_exceptions(self):
|
||||
def cm():
|
||||
try:
|
||||
yield 1
|
||||
except ValueError:
|
||||
try:
|
||||
raise TypeError
|
||||
finally:
|
||||
pass
|
||||
|
||||
try:
|
||||
gen = cm()
|
||||
next(gen)
|
||||
gen.throw(ValueError)
|
||||
except TypeError:
|
||||
pass
|
||||
assert sys.exc_info()[0] is None
|
||||
|
||||
@make_dynamo_test
|
||||
def test_reraise(self):
|
||||
try:
|
||||
try:
|
||||
raise ValueError
|
||||
except ValueError: # noqa: TRY203
|
||||
raise
|
||||
except ValueError:
|
||||
pass
|
||||
assert sys.exc_info()[0] is None
|
||||
|
||||
@make_dynamo_test
|
||||
def test_raise_finally_simple(self):
|
||||
def fn():
|
||||
try:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
try:
|
||||
raise TypeError
|
||||
finally:
|
||||
pass
|
||||
|
||||
try:
|
||||
fn()
|
||||
except TypeError:
|
||||
pass
|
||||
assert sys.exc_info()[0] is None
|
||||
|
||||
def test_reconstruct___context__(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
@ -574,6 +761,54 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
with self.assertRaisesRegex(TypeError, "exception cause must be"):
|
||||
fn(t, e)
|
||||
|
||||
@parametrize(
|
||||
"ex",
|
||||
[TypeError, CustomException],
|
||||
name_fn=lambda x: x.__name__,
|
||||
)
|
||||
@make_dynamo_test
|
||||
def test_set___cause__(self, ex):
|
||||
def fn():
|
||||
try:
|
||||
raise ex
|
||||
except ex:
|
||||
raise TypeError from None
|
||||
|
||||
try:
|
||||
fn()
|
||||
except TypeError as e:
|
||||
assert isinstance(e.__context__, ex)
|
||||
assert e.__cause__ is None
|
||||
assert e.__suppress_context__ is True
|
||||
|
||||
@parametrize(
|
||||
"ex",
|
||||
[RuntimeError, CustomException],
|
||||
name_fn=lambda x: x.__name__,
|
||||
)
|
||||
@make_dynamo_test
|
||||
def test_set___cause___error(self, ex):
|
||||
def fn():
|
||||
try:
|
||||
raise ex
|
||||
except Exception as e:
|
||||
e.__cause__ = 2
|
||||
raise
|
||||
|
||||
z = 0
|
||||
|
||||
try:
|
||||
fn()
|
||||
except TypeError as e:
|
||||
z = 1
|
||||
assert e.args == (
|
||||
"exception cause must be None or derive from BaseException",
|
||||
)
|
||||
except Exception:
|
||||
raise AssertionError from None
|
||||
|
||||
assert z == 1
|
||||
|
||||
def test_user_defined_exception_variable(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
@ -852,6 +1087,9 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertIs(a.__context__, c)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ExceptionTests)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
|
@ -7,7 +7,7 @@ from collections import OrderedDict
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from torch._dynamo.exc import InternalTorchDynamoError, Unsupported
|
||||
from torch._dynamo.exc import Unsupported
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
|
||||
from torch._dynamo.utils import counters
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -1239,7 +1239,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
|
||||
y = self._compile_check(fn, (t,))
|
||||
self.assertEqual(y, t.sin() + t.cos())
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
|
||||
def test_throw_with_finally(self):
|
||||
z = 0
|
||||
|
||||
@ -1296,24 +1295,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
|
||||
self.assertEqual(y, t.sin() + t.cos())
|
||||
self.assertEqual(z, 101)
|
||||
|
||||
def test_throw_three_arguments(self):
|
||||
def whoo(t):
|
||||
try:
|
||||
yield t.sin()
|
||||
except ValueError:
|
||||
yield t.cos()
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = whoo(t)
|
||||
a = next(gen)
|
||||
b = gen.throw(ValueError, "Error", None)
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2)
|
||||
with self.assertRaises(InternalTorchDynamoError):
|
||||
fn(t)
|
||||
|
||||
def test_throw_no_yield_after_throw(self):
|
||||
z = 0
|
||||
|
||||
@ -1420,7 +1401,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
|
||||
with self.assertRaises(Unsupported):
|
||||
fn(t)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
|
||||
def test_throw_try_except_finally(self):
|
||||
z = 0
|
||||
|
||||
@ -1638,7 +1618,6 @@ class GeneratorThrowCpythonTests(GeneratorTestsBase):
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
|
||||
def test_exception_context_with_yield_inside_generator(self):
|
||||
# Check that the context is also available from inside the generator
|
||||
# with yield, as opposed to outside.
|
||||
|
@ -118,6 +118,7 @@ from .replay_record import ExecutionRecord
|
||||
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
|
||||
from .symbolic_convert import (
|
||||
DistributedState,
|
||||
ExceptionStack,
|
||||
InstructionTranslator,
|
||||
LocalState,
|
||||
SpeculationLog,
|
||||
@ -689,6 +690,7 @@ def _compile(
|
||||
nonlocal output
|
||||
nonlocal tracer
|
||||
speculation_log.restart()
|
||||
exn_vt_stack = ExceptionStack()
|
||||
tracer = InstructionTranslator(
|
||||
instructions,
|
||||
code,
|
||||
@ -704,6 +706,7 @@ def _compile(
|
||||
export_constraints,
|
||||
frame_state=frame_state,
|
||||
speculation_log=speculation_log,
|
||||
exn_vt_stack=exn_vt_stack,
|
||||
distributed_state=distributed_state,
|
||||
)
|
||||
|
||||
|
@ -362,7 +362,7 @@ def raise_observed_exception(
|
||||
# CPython here raises an exception. Since there is no python code, we have to manually setup the exception
|
||||
# stack and raise the exception.
|
||||
exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type]
|
||||
tx.exn_vt_stack.append(exception_vt)
|
||||
tx.exn_vt_stack.set_current_exception(exception_vt)
|
||||
raise observed_exception_map[exc_type]
|
||||
|
||||
|
||||
@ -391,7 +391,7 @@ def handle_observed_exception(tx: Any) -> None:
|
||||
#
|
||||
|
||||
# Fortunately this translates to a simple pop from the exn_vt_stack
|
||||
tx.exn_vt_stack.pop()
|
||||
tx.exn_vt_stack.clear_current_exception()
|
||||
|
||||
|
||||
# These exceptions are ok to fallback to eager/graph_break.
|
||||
|
@ -44,7 +44,7 @@ import traceback
|
||||
import types
|
||||
import typing
|
||||
import weakref
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
from typing import Any, Callable, cast, NoReturn, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@ -383,9 +383,11 @@ class BlockStackEntry:
|
||||
and hasattr(self.with_context, "target_values")
|
||||
and self.with_context.target_values
|
||||
):
|
||||
return ReenterWith(self.stack_index, tuple(self.with_context.target_values))
|
||||
return ReenterWith(
|
||||
self.stack_index - 1, tuple(self.with_context.target_values)
|
||||
)
|
||||
else:
|
||||
return ReenterWith(self.stack_index)
|
||||
return ReenterWith(self.stack_index - 1)
|
||||
|
||||
def exit(self, tx, is_graph_break):
|
||||
assert self.with_context is not None
|
||||
@ -956,6 +958,100 @@ class BytecodeDistpatchTableMeta(type):
|
||||
cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExceptionStack:
|
||||
"""
|
||||
Exception stack that it is shared among all InstructionTranslator instances
|
||||
"""
|
||||
|
||||
# Exception handling in CPython is a bit confusing and some of the bytecode
|
||||
# have a slightly different behavior than what is is documented. While reading
|
||||
# the documentation, is important to notice that the terms "current exception"
|
||||
# and "stack" sometimes refers to a C variable with the same name and the
|
||||
# exception stack, respectively.
|
||||
#
|
||||
# The lifetime of an exception is (Python 3.11+):
|
||||
# + tx._raise_exception_variable(...) := sets the current_exception variable
|
||||
# + PUSH_EXC_INFO := pushes the current_exception to the *exception stack*
|
||||
# + POP_EXCEPT := pops TOS from the *exception stack*
|
||||
|
||||
_exc_stack: list[VariableTracker] = dataclasses.field(default_factory=list)
|
||||
_current_exception: Optional[VariableTracker] = dataclasses.field(default=None)
|
||||
|
||||
def clear_current_exception(self):
|
||||
self._current_exception = None
|
||||
|
||||
def set_current_exception(self, val):
|
||||
self._set_context_and_break_context_reference_cycle(val)
|
||||
self._current_exception = val
|
||||
|
||||
def move_current_exception_to_stack(self):
|
||||
assert self._current_exception is not None
|
||||
self.append(self._current_exception)
|
||||
self.clear_current_exception()
|
||||
|
||||
def get_current_exception(self):
|
||||
assert self._current_exception is not None
|
||||
return self._current_exception
|
||||
|
||||
def _set_context_recursive(self, val, prev_idx):
|
||||
if (ctx := val.__context__) and type(ctx) is not ConstantVariable:
|
||||
return val
|
||||
if len(self._exc_stack) + prev_idx > 0:
|
||||
prev = self._exc_stack[prev_idx]
|
||||
self._set_context_recursive(prev, prev_idx - 1)
|
||||
val.set_context(prev)
|
||||
return val
|
||||
|
||||
def _break_context_reference_cycle(self, val):
|
||||
# See test_exceptions::test_raise_does_not_create_context_chain_cycle
|
||||
# Based on https://github.com/python/cpython/blob/e635bf2e49797ecb976ce45a67fce2201a25ca68/Python/errors.c#L207-L228
|
||||
# As noted on CPython, this is O(chain length) but the context chains
|
||||
# are usually very small
|
||||
o = slow_o = val
|
||||
slow_update_toggle = False # floyd's algorithm for detecting cycle
|
||||
while True:
|
||||
context = o.__context__
|
||||
if type(context) is ConstantVariable: # context not set
|
||||
break
|
||||
|
||||
if context is val:
|
||||
o.set_context(ConstantVariable(None))
|
||||
break
|
||||
|
||||
o = context
|
||||
if o is slow_o:
|
||||
# pre-existing cycle - all exceptions on the path were
|
||||
# visited and checked
|
||||
break
|
||||
|
||||
if slow_update_toggle:
|
||||
slow_o = slow_o.__context__ # visited all exceptions
|
||||
slow_update_toggle = not slow_update_toggle
|
||||
|
||||
def _set_context_and_break_context_reference_cycle(self, val):
|
||||
# set Exception.__context__
|
||||
self._set_context_recursive(val, len(self._exc_stack) - 1)
|
||||
self._break_context_reference_cycle(val)
|
||||
|
||||
def pop(self):
|
||||
return self._exc_stack.pop()
|
||||
|
||||
def append(self, val):
|
||||
self._exc_stack.append(val)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._exc_stack)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self._exc_stack[index]
|
||||
|
||||
def __str__(self):
|
||||
return f"{self._exc_stack=} - {self._current_exception=}"
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
class InstructionTranslatorBase(
|
||||
metaclass=BytecodeDistpatchTableMeta,
|
||||
):
|
||||
@ -975,7 +1071,7 @@ class InstructionTranslatorBase(
|
||||
inconsistent_side_effects: bool
|
||||
current_speculation: Optional[SpeculationEntry]
|
||||
dispatch_table: list[Any]
|
||||
exn_vt_stack: list[VariableTracker]
|
||||
exn_vt_stack: ExceptionStack
|
||||
exec_recorder: Optional[ExecutionRecorder]
|
||||
strict_checks_fn: Optional[Callable[[VariableTracker], bool]]
|
||||
start_point: Optional[int]
|
||||
@ -1655,43 +1751,7 @@ class InstructionTranslatorBase(
|
||||
self.push(ConstantVariable.create(None))
|
||||
self.jump(inst)
|
||||
|
||||
def _raise_exception_variable(self, inst):
|
||||
def set_context_recursive(val, prev_idx):
|
||||
if (ctx := val.__context__) and type(ctx) is not ConstantVariable:
|
||||
return val
|
||||
if len(self.exn_vt_stack) + prev_idx > 0:
|
||||
prev = self.exn_vt_stack[prev_idx]
|
||||
set_context_recursive(prev, prev_idx - 1)
|
||||
val.set_context(prev)
|
||||
return val
|
||||
|
||||
def break_context_reference_cycle(val):
|
||||
# See test_exceptions::test_raise_does_not_create_context_chain_cycle
|
||||
# Based on https://github.com/python/cpython/blob/e635bf2e49797ecb976ce45a67fce2201a25ca68/Python/errors.c#L207-L228
|
||||
# As noted on CPython, this is O(chain length) but the context chains
|
||||
# are usually very small
|
||||
o = slow_o = val
|
||||
slow_update_toggle = False # floyd's algorithm for detecting cycle
|
||||
while True:
|
||||
context = o.__context__
|
||||
if type(context) is ConstantVariable: # context not set
|
||||
break
|
||||
|
||||
if context is val:
|
||||
o.set_context(ConstantVariable(None))
|
||||
break
|
||||
|
||||
o = context
|
||||
if o is slow_o:
|
||||
# pre-existing cycle - all exceptions on the path were
|
||||
# visited and checked
|
||||
break
|
||||
|
||||
if slow_update_toggle:
|
||||
slow_o = slow_o.__context__ # visited all exceptions
|
||||
slow_update_toggle = not slow_update_toggle
|
||||
|
||||
val = self.pop()
|
||||
def _raise_exception_variable(self, val) -> NoReturn:
|
||||
# User can raise exception in 2 ways
|
||||
# 1) raise exception type - raise NotImplementedError
|
||||
# 2) raise execption instance - raise NotImplemetedError("foo")
|
||||
@ -1705,6 +1765,7 @@ class InstructionTranslatorBase(
|
||||
val = val.call_function(self, [], {}) # type: ignore[arg-type]
|
||||
|
||||
# Handle https://peps.python.org/pep-0479/
|
||||
# CPython 3.12+ has a specific bytecode instruction (CALL_INTRINSIC_1 3) for this
|
||||
if (
|
||||
is_generator(self.f_code)
|
||||
and isinstance(val, variables.ExceptionVariable)
|
||||
@ -1712,12 +1773,8 @@ class InstructionTranslatorBase(
|
||||
):
|
||||
val = variables.BuiltinVariable(RuntimeError).call_function(self, [], {}) # type: ignore[arg-type]
|
||||
|
||||
# set Exception.__context__
|
||||
set_context_recursive(val, len(self.exn_vt_stack) - 1)
|
||||
break_context_reference_cycle(val)
|
||||
|
||||
# Save the exception in a global data structure
|
||||
self.exn_vt_stack.append(val)
|
||||
self.exn_vt_stack.set_current_exception(val)
|
||||
|
||||
# 2) when user raises exception instance
|
||||
if self._isinstance_exception(val):
|
||||
@ -1732,26 +1789,29 @@ class InstructionTranslatorBase(
|
||||
|
||||
def RAISE_VARARGS(self, inst):
|
||||
if inst.arg == 0:
|
||||
# duplicate the top of the stack and re-raise it
|
||||
if sys.version_info < (3, 11):
|
||||
unimplemented_v2(
|
||||
gb_type="Re-raise with no arguments",
|
||||
context="",
|
||||
explanation="Dynamo does not support re-raising the previous exception "
|
||||
"in Python < 3.11 (empty `raise`)",
|
||||
hints=[],
|
||||
)
|
||||
assert self._isinstance_exception(self.stack[-1])
|
||||
self.stack.append(self.stack[-1])
|
||||
self._raise_exception_variable(inst)
|
||||
# re-raise the previous exception. Here CPython refers to the exception
|
||||
# on top of the exception stack
|
||||
assert len(self.exn_vt_stack)
|
||||
val = self.exn_vt_stack[-1]
|
||||
assert self._isinstance_exception(val), val
|
||||
self._raise_exception_variable(val)
|
||||
elif inst.arg == 1:
|
||||
self._raise_exception_variable(inst)
|
||||
# raise TOS
|
||||
val = self.stack[-1]
|
||||
self._raise_exception_variable(val)
|
||||
else:
|
||||
# Support raise .. from None ... Dynamo does not track __cause__ and other attributes of exception. So we
|
||||
# ignore `from None` part.
|
||||
# raise .. from None
|
||||
from_vt = self.pop()
|
||||
if isinstance(from_vt, ConstantVariable) and from_vt.value is None:
|
||||
self._raise_exception_variable(inst)
|
||||
val = self.pop()
|
||||
try:
|
||||
self._raise_exception_variable(val)
|
||||
finally:
|
||||
# Update __cause__/__supppress_context__ in the raised exception
|
||||
curr_exc = self.exn_vt_stack.get_current_exception()
|
||||
curr_exc.call_setattr(
|
||||
self, ConstantVariable("__cause__"), ConstantVariable(None)
|
||||
)
|
||||
unimplemented_v2(
|
||||
gb_type="Re-raise with 2 arguments",
|
||||
context=str(from_vt),
|
||||
@ -1774,16 +1834,27 @@ class InstructionTranslatorBase(
|
||||
self.RERAISE(inst)
|
||||
|
||||
def RERAISE(self, inst):
|
||||
# https://docs.python.org/3/library/dis.html#opcode-RERAISE
|
||||
# Re-raises the exception currently on top of the stack. If oparg is
|
||||
# non-zero, pops an additional value from the stack which is used to
|
||||
# set f_lasti of the current frame.
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
# RERAISE is currently supported in a narrow case of `raise ... from None`
|
||||
self._raise_exception_variable(inst)
|
||||
unimplemented_v2(
|
||||
gb_type="RERAISE in Python < 3.11",
|
||||
context="",
|
||||
explanation="RERAISE bytecode (https://docs.python.org/3.10/library/dis.html#opcode-RERAISE) "
|
||||
"not supported in Python < 3.11. This bytecode is generated by try/with blocks.",
|
||||
hints=[],
|
||||
)
|
||||
val = self.pop()
|
||||
if inst.argval:
|
||||
# RERAISE 1
|
||||
_ = self.pop()
|
||||
self._raise_exception_variable(val)
|
||||
else:
|
||||
# RERAISE 0
|
||||
self.push(val)
|
||||
self._raise_exception_variable(val)
|
||||
else:
|
||||
_exc = self.pop()
|
||||
val = self.pop()
|
||||
_tb = self.pop()
|
||||
self._raise_exception_variable(val)
|
||||
|
||||
def _isinstance_exception(self, val):
|
||||
return isinstance(
|
||||
@ -1813,7 +1884,7 @@ class InstructionTranslatorBase(
|
||||
else:
|
||||
assert len(self.stack) >= 7
|
||||
fn = self.stack[-7]
|
||||
val = self.stack[-4]
|
||||
val = self.stack[-2]
|
||||
assert self._isinstance_exception(val)
|
||||
typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined]
|
||||
tb = ConstantVariable(None)
|
||||
@ -1843,8 +1914,7 @@ class InstructionTranslatorBase(
|
||||
)
|
||||
|
||||
# 3) push the exception to the stack
|
||||
assert len(self.exn_vt_stack)
|
||||
self.push(self.exn_vt_stack[-1])
|
||||
self.push(self.exn_vt_stack.get_current_exception())
|
||||
|
||||
# 4) jump to the handler
|
||||
self.jump(exn_tab_entry)
|
||||
@ -1867,15 +1937,13 @@ class InstructionTranslatorBase(
|
||||
if len(self.block_stack):
|
||||
# base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455
|
||||
|
||||
assert len(self.exn_vt_stack)
|
||||
exception_var = self.exn_vt_stack[-1]
|
||||
|
||||
block_stack_entry = self.block_stack.pop()
|
||||
|
||||
while block_stack_entry.inst.opname == "EXCEPT_HANDLER":
|
||||
# TODO(anijain2305) - This is not tested .. unable to create a testcase
|
||||
# https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
|
||||
self.popn(3)
|
||||
self.exn_vt_stack.pop()
|
||||
if len(self.block_stack) == 0:
|
||||
# No handler found in this frame. Bubble the exception to the parent
|
||||
# instruction translater.
|
||||
@ -1892,15 +1960,8 @@ class InstructionTranslatorBase(
|
||||
raise raised_exception
|
||||
block_stack_entry = self.block_stack.pop()
|
||||
|
||||
if block_stack_entry.inst.opname != "SETUP_FINALLY":
|
||||
unimplemented_v2(
|
||||
gb_type="Exception raised with invalid exception handler",
|
||||
context="",
|
||||
explanation="Exception raised when top of the block stack "
|
||||
"is not exception handler (e.g. try .. with .. except). "
|
||||
f"Current TOS is {block_stack_entry.inst}",
|
||||
hints=[],
|
||||
)
|
||||
exception_var = self.exn_vt_stack.get_current_exception()
|
||||
self.exn_vt_stack.move_current_exception_to_stack()
|
||||
|
||||
# 1) pop values from the stack until it matches the stack depth
|
||||
# for the handler
|
||||
@ -1954,15 +2015,37 @@ class InstructionTranslatorBase(
|
||||
raise raised_exception
|
||||
|
||||
def PUSH_EXC_INFO(self, inst):
|
||||
# https://docs.python.org/3/library/dis.html#opcode-PUSH_EXC_INFO
|
||||
# Pops a value from the stack. Pushes the current exception to the top
|
||||
# of the stack. Pushes the value originally popped back to the stack.
|
||||
#
|
||||
# The behavior of this opcode in CPython is a bit different than what it
|
||||
# is described. It pops a value from the stack, pushes the top of the
|
||||
# exception stack to the interpreter stack and moves the
|
||||
# "current exception" to the exception stack.
|
||||
#
|
||||
# As an example, suppose the stack is in the following state:
|
||||
# + stack = [..., ConstantVariable(1), ConstantVariable(2)]
|
||||
# + current_exception = TypeError
|
||||
# + exception_stack = [ValueError]
|
||||
#
|
||||
# After PUSH_EXC_INFO is executed
|
||||
# + stack = [..., ConstantVariable(1), ValueError, ConstantVariable(2)]
|
||||
# + current_exception = None
|
||||
# + exception_stack = [ValueError, TypeError]
|
||||
|
||||
val = self.pop()
|
||||
assert len(self.exn_vt_stack)
|
||||
self.push(self.exn_vt_stack[-1])
|
||||
if len(self.exn_vt_stack) == 0:
|
||||
prev_exc = ConstantVariable(None)
|
||||
else:
|
||||
prev_exc = self.exn_vt_stack[-1]
|
||||
self.push(prev_exc)
|
||||
self.push(val)
|
||||
self.exn_vt_stack.move_current_exception_to_stack()
|
||||
|
||||
def POP_EXCEPT(self, inst):
|
||||
if sys.version_info >= (3, 11):
|
||||
val = self.pop()
|
||||
assert self._isinstance_exception(val)
|
||||
_ = self.pop()
|
||||
# This exception is handled and therefore we can clear the error indicator
|
||||
assert len(self.exn_vt_stack)
|
||||
self.exn_vt_stack.pop()
|
||||
@ -2916,6 +2999,8 @@ class InstructionTranslatorBase(
|
||||
else:
|
||||
target = inst.target
|
||||
|
||||
self.push(exit)
|
||||
|
||||
if target:
|
||||
if isinstance(self, InstructionTranslator):
|
||||
self.block_stack.append(
|
||||
@ -2924,7 +3009,6 @@ class InstructionTranslatorBase(
|
||||
else:
|
||||
self.block_stack.append(BlockStackEntry(inst, target, len(self.stack)))
|
||||
|
||||
self.push(exit)
|
||||
self.push(ctx.enter(self))
|
||||
|
||||
def append_prefix_inst(self, inst):
|
||||
@ -3125,6 +3209,7 @@ class InstructionTranslatorBase(
|
||||
export: bool,
|
||||
inline_depth: int,
|
||||
speculation_log: SpeculationLog,
|
||||
exn_vt_stack: ExceptionStack,
|
||||
distributed_state: Optional[DistributedState],
|
||||
# This determines whether to use the execution recorder.
|
||||
closure: Optional[tuple[types.CellType]] = None,
|
||||
@ -3149,7 +3234,7 @@ class InstructionTranslatorBase(
|
||||
self.kw_names = None
|
||||
self.accept_prefix_inst = True
|
||||
self.prefix_insts = []
|
||||
self.exn_vt_stack = []
|
||||
self.exn_vt_stack = exn_vt_stack
|
||||
|
||||
# Properties of the input/output code
|
||||
self.instructions: list[Instruction] = instructions
|
||||
@ -3233,6 +3318,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
export_constraints,
|
||||
frame_state,
|
||||
speculation_log: SpeculationLog,
|
||||
exn_vt_stack: ExceptionStack,
|
||||
distributed_state: Optional[DistributedState],
|
||||
) -> None:
|
||||
_step_logger()(
|
||||
@ -3266,6 +3352,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
export=export,
|
||||
inline_depth=0,
|
||||
speculation_log=speculation_log,
|
||||
exn_vt_stack=exn_vt_stack,
|
||||
distributed_state=distributed_state,
|
||||
)
|
||||
|
||||
@ -3818,9 +3905,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
self.run()
|
||||
except exc.ObservedException as e:
|
||||
msg = f"Observed exception DURING INLING {code} : {e}"
|
||||
# TODO(anijain2305) - This works but we should probably have a
|
||||
# global/central data structure for the exception stack.
|
||||
parent.exn_vt_stack.extend(self.exn_vt_stack)
|
||||
log.debug(msg)
|
||||
# bubble up the exception to the parent frame.
|
||||
raise
|
||||
@ -3895,6 +3979,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
export=parent.export,
|
||||
inline_depth=parent.inline_depth + 1,
|
||||
speculation_log=parent.speculation_log,
|
||||
exn_vt_stack=parent.exn_vt_stack,
|
||||
distributed_state=parent.distributed_state,
|
||||
)
|
||||
self.funcvar = funcvar
|
||||
|
@ -41,7 +41,6 @@ from ..bytecode_transformation import create_call_function, create_rot_n, is_gen
|
||||
from ..exc import (
|
||||
get_dynamo_observed_exception,
|
||||
handle_observed_exception,
|
||||
IncorrectUsage,
|
||||
InfiniteGeneratorError,
|
||||
ObservedException,
|
||||
ObservedGeneratorExit,
|
||||
@ -521,7 +520,6 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
||||
with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
|
||||
return tracer.inline_call_()
|
||||
except ObservedException as e:
|
||||
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
|
||||
raise e
|
||||
except InfiniteGeneratorError:
|
||||
# test/dynamo/test_misc.py::test_iterator_limit
|
||||
@ -550,9 +548,8 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
||||
|
||||
def _setup_exception(self, tx, exc):
|
||||
tracer = self._get_inline_tracer(tx)
|
||||
tracer.push(exc)
|
||||
try:
|
||||
tracer._raise_exception_variable(None)
|
||||
tracer._raise_exception_variable(exc)
|
||||
except ObservedException as e:
|
||||
# if no handler is available (i.e. user code doesn't catch it), the
|
||||
# exception is raised again.
|
||||
@ -664,19 +661,16 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
||||
# * If the generator function does not catch the passed-in exception,
|
||||
# or raises a different exception, then that exception propagates to the caller.
|
||||
|
||||
if len(args) > 1:
|
||||
raise IncorrectUsage(
|
||||
"the (type, exc, tb) signature of throw() is deprecated, "
|
||||
"use the single-arg signature instead."
|
||||
)
|
||||
|
||||
# Setup the exception table and jump target in case of try...finally
|
||||
tracer = self._get_inline_tracer(tx)
|
||||
try:
|
||||
self._setup_exception(tx, args[0])
|
||||
except ObservedException:
|
||||
# In Python 3.9, the exception is represented as a triple (typ, val, tb)
|
||||
# In such cases, we re-raise the exception object given to avoid
|
||||
# creating a new object, so that IS_OP works.
|
||||
# See: https://github.com/pytorch/pytorch/pull/146496
|
||||
self._setup_exception(tx, args[1] if len(args) == 3 else args[0])
|
||||
except ObservedException: # noqa: TRY203
|
||||
# propagate the exception back to the parent caller
|
||||
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
|
||||
raise
|
||||
|
||||
retval = self.next_variable(tx)
|
||||
@ -749,9 +743,6 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
||||
except get_dynamo_observed_exception(exc_type):
|
||||
# We should get back the exception raised before.
|
||||
pass
|
||||
except ObservedException:
|
||||
# Propagate anything else back to the parent caller
|
||||
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
|
||||
else:
|
||||
raise_observed_exception(RuntimeError, tracer)
|
||||
return retval
|
||||
|
Reference in New Issue
Block a user