Correctly propagate exception to parent tx (#146502)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146502
Approved by: https://github.com/anijain2305, https://github.com/williamwen42, https://github.com/zou3519
ghstack dependencies: #146504, #146499
This commit is contained in:
Guilherme Leobas
2025-03-11 13:33:09 +00:00
committed by PyTorch MergeBot
parent fb53e9e514
commit daff65d671
6 changed files with 431 additions and 135 deletions

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
import contextlib
import sys
import unittest
@ -11,7 +12,11 @@ import torch.nn
import torch.utils.checkpoint
from torch._dynamo.bytecode_transformation import Instruction
from torch._dynamo.symbolic_convert import SpeculationLog, SpeculationLogDivergence
from torch.testing._internal.common_utils import make_dynamo_test
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
make_dynamo_test,
parametrize,
)
class CustomException(Exception):
@ -123,6 +128,33 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x)
self.assertEqual(ref, res)
@make_dynamo_test
def test_propagate_exception_inside_ctx_manager(self):
@contextlib.contextmanager
def cm():
try:
yield
except BaseException:
raise ValueError # noqa: B904
@contextlib.contextmanager
def nothing():
try:
yield
finally:
pass
z = 0
with nothing():
try:
with cm():
raise IndexError
except ValueError:
z = 1
except IndexError:
z = 2
assert z == 1
def test_exception_else(self):
def gn(x):
return torch.cos(x)
@ -145,6 +177,64 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x)
self.assertEqual(ref, res)
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@make_dynamo_test
def test_raise_match(self):
a = AttributeError
b = BytesWarning
c = ConnectionError
d = DeprecationWarning
e = Exception
def fn(a, b):
try:
raise a
finally:
raise b
def fix_exc_context(frame_exc, new_exc, old_exc):
# slightly change from ExitStack.fix_exc_context function
while 1:
exc_context = new_exc.__context__
if exc_context is None or exc_context is old_exc:
return
if exc_context is frame_exc:
break
new_exc = exc_context
new_exc.__context__ = old_exc
@contextlib.contextmanager
def ctx():
try:
yield
finally:
frame_exc = prev_exc = sys.exc_info()
args = [(d, c), (b, a)]
for x, y in args:
try:
fn(x, y)
except BaseException:
new_exc = sys.exc_info()
fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1])
prev_exc = new_exc
try:
fixed_ctx = prev_exc[1].__context__
raise prev_exc[1]
except BaseException:
prev_exc[1].__context__ = fixed_ctx
raise
try:
with ctx():
raise e
except Exception as exc:
assert isinstance(exc, a)
assert isinstance(exc.__context__, b)
assert isinstance(exc.__context__.__context__, c)
assert isinstance(exc.__context__.__context__.__context__, d)
assert isinstance(exc.__context__.__context__.__context__.__context__, e)
# TODO(anijain2305) - does not work with fullgraph=True
def test_exception_with_another_exception2(self):
def gn(x):
@ -455,6 +545,103 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
@make_dynamo_test
def test_reraise_first_exc(self):
def fn():
try:
raise ZeroDivisionError
except ZeroDivisionError:
try:
raise ValueError
except ValueError:
pass
raise
try:
fn()
except ZeroDivisionError:
pass
assert sys.exc_info()[0] is None
@make_dynamo_test
def test_ensure_exception_is_active_after_try_except_block(self):
try:
try:
raise ZeroDivisionError
except ZeroDivisionError:
for exc in (KeyError, IndexError):
try:
raise exc
except exc:
pass
raise
except ZeroDivisionError:
pass
assert sys.exc_info()[0] is None
@make_dynamo_test
def test_ensure_exception_is_active_inside_try_except_block(self):
try:
try:
raise ZeroDivisionError
except ZeroDivisionError:
for exc in (KeyError, IndexError):
try:
raise exc
except exc as e:
assert isinstance(e.__context__, ZeroDivisionError)
raise
except ZeroDivisionError:
pass
assert sys.exc_info()[0] is None
@make_dynamo_test
def test_handle_all_exceptions(self):
def cm():
try:
yield 1
except ValueError:
try:
raise TypeError
finally:
pass
try:
gen = cm()
next(gen)
gen.throw(ValueError)
except TypeError:
pass
assert sys.exc_info()[0] is None
@make_dynamo_test
def test_reraise(self):
try:
try:
raise ValueError
except ValueError: # noqa: TRY203
raise
except ValueError:
pass
assert sys.exc_info()[0] is None
@make_dynamo_test
def test_raise_finally_simple(self):
def fn():
try:
raise ValueError
except ValueError:
try:
raise TypeError
finally:
pass
try:
fn()
except TypeError:
pass
assert sys.exc_info()[0] is None
def test_reconstruct___context__(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
@ -574,6 +761,54 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
with self.assertRaisesRegex(TypeError, "exception cause must be"):
fn(t, e)
@parametrize(
"ex",
[TypeError, CustomException],
name_fn=lambda x: x.__name__,
)
@make_dynamo_test
def test_set___cause__(self, ex):
def fn():
try:
raise ex
except ex:
raise TypeError from None
try:
fn()
except TypeError as e:
assert isinstance(e.__context__, ex)
assert e.__cause__ is None
assert e.__suppress_context__ is True
@parametrize(
"ex",
[RuntimeError, CustomException],
name_fn=lambda x: x.__name__,
)
@make_dynamo_test
def test_set___cause___error(self, ex):
def fn():
try:
raise ex
except Exception as e:
e.__cause__ = 2
raise
z = 0
try:
fn()
except TypeError as e:
z = 1
assert e.args == (
"exception cause must be None or derive from BaseException",
)
except Exception:
raise AssertionError from None
assert z == 1
def test_user_defined_exception_variable(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
@ -852,6 +1087,9 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(a.__context__, c)
instantiate_parametrized_tests(ExceptionTests)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -7,7 +7,7 @@ from collections import OrderedDict
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.exc import InternalTorchDynamoError, Unsupported
from torch._dynamo.exc import Unsupported
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import (
@ -1239,7 +1239,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
y = self._compile_check(fn, (t,))
self.assertEqual(y, t.sin() + t.cos())
@unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
def test_throw_with_finally(self):
z = 0
@ -1296,24 +1295,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
self.assertEqual(y, t.sin() + t.cos())
self.assertEqual(z, 101)
def test_throw_three_arguments(self):
def whoo(t):
try:
yield t.sin()
except ValueError:
yield t.cos()
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = whoo(t)
a = next(gen)
b = gen.throw(ValueError, "Error", None)
return a + b
t = torch.randn(2)
with self.assertRaises(InternalTorchDynamoError):
fn(t)
def test_throw_no_yield_after_throw(self):
z = 0
@ -1420,7 +1401,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
with self.assertRaises(Unsupported):
fn(t)
@unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
def test_throw_try_except_finally(self):
z = 0
@ -1638,7 +1618,6 @@ class GeneratorThrowCpythonTests(GeneratorTestsBase):
self._compile_check(fn)
@unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
def test_exception_context_with_yield_inside_generator(self):
# Check that the context is also available from inside the generator
# with yield, as opposed to outside.

View File

@ -118,6 +118,7 @@ from .replay_record import ExecutionRecord
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .symbolic_convert import (
DistributedState,
ExceptionStack,
InstructionTranslator,
LocalState,
SpeculationLog,
@ -689,6 +690,7 @@ def _compile(
nonlocal output
nonlocal tracer
speculation_log.restart()
exn_vt_stack = ExceptionStack()
tracer = InstructionTranslator(
instructions,
code,
@ -704,6 +706,7 @@ def _compile(
export_constraints,
frame_state=frame_state,
speculation_log=speculation_log,
exn_vt_stack=exn_vt_stack,
distributed_state=distributed_state,
)

View File

@ -362,7 +362,7 @@ def raise_observed_exception(
# CPython here raises an exception. Since there is no python code, we have to manually setup the exception
# stack and raise the exception.
exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type]
tx.exn_vt_stack.append(exception_vt)
tx.exn_vt_stack.set_current_exception(exception_vt)
raise observed_exception_map[exc_type]
@ -391,7 +391,7 @@ def handle_observed_exception(tx: Any) -> None:
#
# Fortunately this translates to a simple pop from the exn_vt_stack
tx.exn_vt_stack.pop()
tx.exn_vt_stack.clear_current_exception()
# These exceptions are ok to fallback to eager/graph_break.

View File

@ -44,7 +44,7 @@ import traceback
import types
import typing
import weakref
from typing import Any, Callable, cast, Optional, Union
from typing import Any, Callable, cast, NoReturn, Optional, Union
from unittest.mock import patch
import torch
@ -383,9 +383,11 @@ class BlockStackEntry:
and hasattr(self.with_context, "target_values")
and self.with_context.target_values
):
return ReenterWith(self.stack_index, tuple(self.with_context.target_values))
return ReenterWith(
self.stack_index - 1, tuple(self.with_context.target_values)
)
else:
return ReenterWith(self.stack_index)
return ReenterWith(self.stack_index - 1)
def exit(self, tx, is_graph_break):
assert self.with_context is not None
@ -956,6 +958,100 @@ class BytecodeDistpatchTableMeta(type):
cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)]
@dataclasses.dataclass
class ExceptionStack:
"""
Exception stack that it is shared among all InstructionTranslator instances
"""
# Exception handling in CPython is a bit confusing and some of the bytecode
# have a slightly different behavior than what is is documented. While reading
# the documentation, is important to notice that the terms "current exception"
# and "stack" sometimes refers to a C variable with the same name and the
# exception stack, respectively.
#
# The lifetime of an exception is (Python 3.11+):
# + tx._raise_exception_variable(...) := sets the current_exception variable
# + PUSH_EXC_INFO := pushes the current_exception to the *exception stack*
# + POP_EXCEPT := pops TOS from the *exception stack*
_exc_stack: list[VariableTracker] = dataclasses.field(default_factory=list)
_current_exception: Optional[VariableTracker] = dataclasses.field(default=None)
def clear_current_exception(self):
self._current_exception = None
def set_current_exception(self, val):
self._set_context_and_break_context_reference_cycle(val)
self._current_exception = val
def move_current_exception_to_stack(self):
assert self._current_exception is not None
self.append(self._current_exception)
self.clear_current_exception()
def get_current_exception(self):
assert self._current_exception is not None
return self._current_exception
def _set_context_recursive(self, val, prev_idx):
if (ctx := val.__context__) and type(ctx) is not ConstantVariable:
return val
if len(self._exc_stack) + prev_idx > 0:
prev = self._exc_stack[prev_idx]
self._set_context_recursive(prev, prev_idx - 1)
val.set_context(prev)
return val
def _break_context_reference_cycle(self, val):
# See test_exceptions::test_raise_does_not_create_context_chain_cycle
# Based on https://github.com/python/cpython/blob/e635bf2e49797ecb976ce45a67fce2201a25ca68/Python/errors.c#L207-L228
# As noted on CPython, this is O(chain length) but the context chains
# are usually very small
o = slow_o = val
slow_update_toggle = False # floyd's algorithm for detecting cycle
while True:
context = o.__context__
if type(context) is ConstantVariable: # context not set
break
if context is val:
o.set_context(ConstantVariable(None))
break
o = context
if o is slow_o:
# pre-existing cycle - all exceptions on the path were
# visited and checked
break
if slow_update_toggle:
slow_o = slow_o.__context__ # visited all exceptions
slow_update_toggle = not slow_update_toggle
def _set_context_and_break_context_reference_cycle(self, val):
# set Exception.__context__
self._set_context_recursive(val, len(self._exc_stack) - 1)
self._break_context_reference_cycle(val)
def pop(self):
return self._exc_stack.pop()
def append(self, val):
self._exc_stack.append(val)
def __len__(self):
return len(self._exc_stack)
def __getitem__(self, index):
return self._exc_stack[index]
def __str__(self):
return f"{self._exc_stack=} - {self._current_exception=}"
__repr__ = __str__
class InstructionTranslatorBase(
metaclass=BytecodeDistpatchTableMeta,
):
@ -975,7 +1071,7 @@ class InstructionTranslatorBase(
inconsistent_side_effects: bool
current_speculation: Optional[SpeculationEntry]
dispatch_table: list[Any]
exn_vt_stack: list[VariableTracker]
exn_vt_stack: ExceptionStack
exec_recorder: Optional[ExecutionRecorder]
strict_checks_fn: Optional[Callable[[VariableTracker], bool]]
start_point: Optional[int]
@ -1655,43 +1751,7 @@ class InstructionTranslatorBase(
self.push(ConstantVariable.create(None))
self.jump(inst)
def _raise_exception_variable(self, inst):
def set_context_recursive(val, prev_idx):
if (ctx := val.__context__) and type(ctx) is not ConstantVariable:
return val
if len(self.exn_vt_stack) + prev_idx > 0:
prev = self.exn_vt_stack[prev_idx]
set_context_recursive(prev, prev_idx - 1)
val.set_context(prev)
return val
def break_context_reference_cycle(val):
# See test_exceptions::test_raise_does_not_create_context_chain_cycle
# Based on https://github.com/python/cpython/blob/e635bf2e49797ecb976ce45a67fce2201a25ca68/Python/errors.c#L207-L228
# As noted on CPython, this is O(chain length) but the context chains
# are usually very small
o = slow_o = val
slow_update_toggle = False # floyd's algorithm for detecting cycle
while True:
context = o.__context__
if type(context) is ConstantVariable: # context not set
break
if context is val:
o.set_context(ConstantVariable(None))
break
o = context
if o is slow_o:
# pre-existing cycle - all exceptions on the path were
# visited and checked
break
if slow_update_toggle:
slow_o = slow_o.__context__ # visited all exceptions
slow_update_toggle = not slow_update_toggle
val = self.pop()
def _raise_exception_variable(self, val) -> NoReturn:
# User can raise exception in 2 ways
# 1) raise exception type - raise NotImplementedError
# 2) raise execption instance - raise NotImplemetedError("foo")
@ -1705,6 +1765,7 @@ class InstructionTranslatorBase(
val = val.call_function(self, [], {}) # type: ignore[arg-type]
# Handle https://peps.python.org/pep-0479/
# CPython 3.12+ has a specific bytecode instruction (CALL_INTRINSIC_1 3) for this
if (
is_generator(self.f_code)
and isinstance(val, variables.ExceptionVariable)
@ -1712,12 +1773,8 @@ class InstructionTranslatorBase(
):
val = variables.BuiltinVariable(RuntimeError).call_function(self, [], {}) # type: ignore[arg-type]
# set Exception.__context__
set_context_recursive(val, len(self.exn_vt_stack) - 1)
break_context_reference_cycle(val)
# Save the exception in a global data structure
self.exn_vt_stack.append(val)
self.exn_vt_stack.set_current_exception(val)
# 2) when user raises exception instance
if self._isinstance_exception(val):
@ -1732,26 +1789,29 @@ class InstructionTranslatorBase(
def RAISE_VARARGS(self, inst):
if inst.arg == 0:
# duplicate the top of the stack and re-raise it
if sys.version_info < (3, 11):
unimplemented_v2(
gb_type="Re-raise with no arguments",
context="",
explanation="Dynamo does not support re-raising the previous exception "
"in Python < 3.11 (empty `raise`)",
hints=[],
)
assert self._isinstance_exception(self.stack[-1])
self.stack.append(self.stack[-1])
self._raise_exception_variable(inst)
# re-raise the previous exception. Here CPython refers to the exception
# on top of the exception stack
assert len(self.exn_vt_stack)
val = self.exn_vt_stack[-1]
assert self._isinstance_exception(val), val
self._raise_exception_variable(val)
elif inst.arg == 1:
self._raise_exception_variable(inst)
# raise TOS
val = self.stack[-1]
self._raise_exception_variable(val)
else:
# Support raise .. from None ... Dynamo does not track __cause__ and other attributes of exception. So we
# ignore `from None` part.
# raise .. from None
from_vt = self.pop()
if isinstance(from_vt, ConstantVariable) and from_vt.value is None:
self._raise_exception_variable(inst)
val = self.pop()
try:
self._raise_exception_variable(val)
finally:
# Update __cause__/__supppress_context__ in the raised exception
curr_exc = self.exn_vt_stack.get_current_exception()
curr_exc.call_setattr(
self, ConstantVariable("__cause__"), ConstantVariable(None)
)
unimplemented_v2(
gb_type="Re-raise with 2 arguments",
context=str(from_vt),
@ -1774,16 +1834,27 @@ class InstructionTranslatorBase(
self.RERAISE(inst)
def RERAISE(self, inst):
# https://docs.python.org/3/library/dis.html#opcode-RERAISE
# Re-raises the exception currently on top of the stack. If oparg is
# non-zero, pops an additional value from the stack which is used to
# set f_lasti of the current frame.
if sys.version_info >= (3, 11):
# RERAISE is currently supported in a narrow case of `raise ... from None`
self._raise_exception_variable(inst)
unimplemented_v2(
gb_type="RERAISE in Python < 3.11",
context="",
explanation="RERAISE bytecode (https://docs.python.org/3.10/library/dis.html#opcode-RERAISE) "
"not supported in Python < 3.11. This bytecode is generated by try/with blocks.",
hints=[],
)
val = self.pop()
if inst.argval:
# RERAISE 1
_ = self.pop()
self._raise_exception_variable(val)
else:
# RERAISE 0
self.push(val)
self._raise_exception_variable(val)
else:
_exc = self.pop()
val = self.pop()
_tb = self.pop()
self._raise_exception_variable(val)
def _isinstance_exception(self, val):
return isinstance(
@ -1813,7 +1884,7 @@ class InstructionTranslatorBase(
else:
assert len(self.stack) >= 7
fn = self.stack[-7]
val = self.stack[-4]
val = self.stack[-2]
assert self._isinstance_exception(val)
typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined]
tb = ConstantVariable(None)
@ -1843,8 +1914,7 @@ class InstructionTranslatorBase(
)
# 3) push the exception to the stack
assert len(self.exn_vt_stack)
self.push(self.exn_vt_stack[-1])
self.push(self.exn_vt_stack.get_current_exception())
# 4) jump to the handler
self.jump(exn_tab_entry)
@ -1867,15 +1937,13 @@ class InstructionTranslatorBase(
if len(self.block_stack):
# base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455
assert len(self.exn_vt_stack)
exception_var = self.exn_vt_stack[-1]
block_stack_entry = self.block_stack.pop()
while block_stack_entry.inst.opname == "EXCEPT_HANDLER":
# TODO(anijain2305) - This is not tested .. unable to create a testcase
# https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
self.popn(3)
self.exn_vt_stack.pop()
if len(self.block_stack) == 0:
# No handler found in this frame. Bubble the exception to the parent
# instruction translater.
@ -1892,15 +1960,8 @@ class InstructionTranslatorBase(
raise raised_exception
block_stack_entry = self.block_stack.pop()
if block_stack_entry.inst.opname != "SETUP_FINALLY":
unimplemented_v2(
gb_type="Exception raised with invalid exception handler",
context="",
explanation="Exception raised when top of the block stack "
"is not exception handler (e.g. try .. with .. except). "
f"Current TOS is {block_stack_entry.inst}",
hints=[],
)
exception_var = self.exn_vt_stack.get_current_exception()
self.exn_vt_stack.move_current_exception_to_stack()
# 1) pop values from the stack until it matches the stack depth
# for the handler
@ -1954,15 +2015,37 @@ class InstructionTranslatorBase(
raise raised_exception
def PUSH_EXC_INFO(self, inst):
# https://docs.python.org/3/library/dis.html#opcode-PUSH_EXC_INFO
# Pops a value from the stack. Pushes the current exception to the top
# of the stack. Pushes the value originally popped back to the stack.
#
# The behavior of this opcode in CPython is a bit different than what it
# is described. It pops a value from the stack, pushes the top of the
# exception stack to the interpreter stack and moves the
# "current exception" to the exception stack.
#
# As an example, suppose the stack is in the following state:
# + stack = [..., ConstantVariable(1), ConstantVariable(2)]
# + current_exception = TypeError
# + exception_stack = [ValueError]
#
# After PUSH_EXC_INFO is executed
# + stack = [..., ConstantVariable(1), ValueError, ConstantVariable(2)]
# + current_exception = None
# + exception_stack = [ValueError, TypeError]
val = self.pop()
assert len(self.exn_vt_stack)
self.push(self.exn_vt_stack[-1])
if len(self.exn_vt_stack) == 0:
prev_exc = ConstantVariable(None)
else:
prev_exc = self.exn_vt_stack[-1]
self.push(prev_exc)
self.push(val)
self.exn_vt_stack.move_current_exception_to_stack()
def POP_EXCEPT(self, inst):
if sys.version_info >= (3, 11):
val = self.pop()
assert self._isinstance_exception(val)
_ = self.pop()
# This exception is handled and therefore we can clear the error indicator
assert len(self.exn_vt_stack)
self.exn_vt_stack.pop()
@ -2916,6 +2999,8 @@ class InstructionTranslatorBase(
else:
target = inst.target
self.push(exit)
if target:
if isinstance(self, InstructionTranslator):
self.block_stack.append(
@ -2924,7 +3009,6 @@ class InstructionTranslatorBase(
else:
self.block_stack.append(BlockStackEntry(inst, target, len(self.stack)))
self.push(exit)
self.push(ctx.enter(self))
def append_prefix_inst(self, inst):
@ -3125,6 +3209,7 @@ class InstructionTranslatorBase(
export: bool,
inline_depth: int,
speculation_log: SpeculationLog,
exn_vt_stack: ExceptionStack,
distributed_state: Optional[DistributedState],
# This determines whether to use the execution recorder.
closure: Optional[tuple[types.CellType]] = None,
@ -3149,7 +3234,7 @@ class InstructionTranslatorBase(
self.kw_names = None
self.accept_prefix_inst = True
self.prefix_insts = []
self.exn_vt_stack = []
self.exn_vt_stack = exn_vt_stack
# Properties of the input/output code
self.instructions: list[Instruction] = instructions
@ -3233,6 +3318,7 @@ class InstructionTranslator(InstructionTranslatorBase):
export_constraints,
frame_state,
speculation_log: SpeculationLog,
exn_vt_stack: ExceptionStack,
distributed_state: Optional[DistributedState],
) -> None:
_step_logger()(
@ -3266,6 +3352,7 @@ class InstructionTranslator(InstructionTranslatorBase):
export=export,
inline_depth=0,
speculation_log=speculation_log,
exn_vt_stack=exn_vt_stack,
distributed_state=distributed_state,
)
@ -3818,9 +3905,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
self.run()
except exc.ObservedException as e:
msg = f"Observed exception DURING INLING {code} : {e}"
# TODO(anijain2305) - This works but we should probably have a
# global/central data structure for the exception stack.
parent.exn_vt_stack.extend(self.exn_vt_stack)
log.debug(msg)
# bubble up the exception to the parent frame.
raise
@ -3895,6 +3979,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
export=parent.export,
inline_depth=parent.inline_depth + 1,
speculation_log=parent.speculation_log,
exn_vt_stack=parent.exn_vt_stack,
distributed_state=parent.distributed_state,
)
self.funcvar = funcvar

View File

@ -41,7 +41,6 @@ from ..bytecode_transformation import create_call_function, create_rot_n, is_gen
from ..exc import (
get_dynamo_observed_exception,
handle_observed_exception,
IncorrectUsage,
InfiniteGeneratorError,
ObservedException,
ObservedGeneratorExit,
@ -521,7 +520,6 @@ class LocalGeneratorObjectVariable(VariableTracker):
with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
return tracer.inline_call_()
except ObservedException as e:
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
raise e
except InfiniteGeneratorError:
# test/dynamo/test_misc.py::test_iterator_limit
@ -550,9 +548,8 @@ class LocalGeneratorObjectVariable(VariableTracker):
def _setup_exception(self, tx, exc):
tracer = self._get_inline_tracer(tx)
tracer.push(exc)
try:
tracer._raise_exception_variable(None)
tracer._raise_exception_variable(exc)
except ObservedException as e:
# if no handler is available (i.e. user code doesn't catch it), the
# exception is raised again.
@ -664,19 +661,16 @@ class LocalGeneratorObjectVariable(VariableTracker):
# * If the generator function does not catch the passed-in exception,
# or raises a different exception, then that exception propagates to the caller.
if len(args) > 1:
raise IncorrectUsage(
"the (type, exc, tb) signature of throw() is deprecated, "
"use the single-arg signature instead."
)
# Setup the exception table and jump target in case of try...finally
tracer = self._get_inline_tracer(tx)
try:
self._setup_exception(tx, args[0])
except ObservedException:
# In Python 3.9, the exception is represented as a triple (typ, val, tb)
# In such cases, we re-raise the exception object given to avoid
# creating a new object, so that IS_OP works.
# See: https://github.com/pytorch/pytorch/pull/146496
self._setup_exception(tx, args[1] if len(args) == 3 else args[0])
except ObservedException: # noqa: TRY203
# propagate the exception back to the parent caller
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
raise
retval = self.next_variable(tx)
@ -749,9 +743,6 @@ class LocalGeneratorObjectVariable(VariableTracker):
except get_dynamo_observed_exception(exc_type):
# We should get back the exception raised before.
pass
except ObservedException:
# Propagate anything else back to the parent caller
tx.exn_vt_stack.extend(tracer.exn_vt_stack)
else:
raise_observed_exception(RuntimeError, tracer)
return retval