mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Correctly propagate exception to parent tx (#146502)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146502 Approved by: https://github.com/anijain2305, https://github.com/williamwen42, https://github.com/zou3519 ghstack dependencies: #146504, #146499
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							fb53e9e514
						
					
				
				
					commit
					daff65d671
				
			@ -1,5 +1,6 @@
 | 
			
		||||
# Owner(s): ["module: dynamo"]
 | 
			
		||||
 | 
			
		||||
import contextlib
 | 
			
		||||
import sys
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
@ -11,7 +12,11 @@ import torch.nn
 | 
			
		||||
import torch.utils.checkpoint
 | 
			
		||||
from torch._dynamo.bytecode_transformation import Instruction
 | 
			
		||||
from torch._dynamo.symbolic_convert import SpeculationLog, SpeculationLogDivergence
 | 
			
		||||
from torch.testing._internal.common_utils import make_dynamo_test
 | 
			
		||||
from torch.testing._internal.common_utils import (
 | 
			
		||||
    instantiate_parametrized_tests,
 | 
			
		||||
    make_dynamo_test,
 | 
			
		||||
    parametrize,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CustomException(Exception):
 | 
			
		||||
@ -123,6 +128,33 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
 | 
			
		||||
        res = opt_fn(x)
 | 
			
		||||
        self.assertEqual(ref, res)
 | 
			
		||||
 | 
			
		||||
    @make_dynamo_test
 | 
			
		||||
    def test_propagate_exception_inside_ctx_manager(self):
 | 
			
		||||
        @contextlib.contextmanager
 | 
			
		||||
        def cm():
 | 
			
		||||
            try:
 | 
			
		||||
                yield
 | 
			
		||||
            except BaseException:
 | 
			
		||||
                raise ValueError  # noqa: B904
 | 
			
		||||
 | 
			
		||||
        @contextlib.contextmanager
 | 
			
		||||
        def nothing():
 | 
			
		||||
            try:
 | 
			
		||||
                yield
 | 
			
		||||
            finally:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        z = 0
 | 
			
		||||
        with nothing():
 | 
			
		||||
            try:
 | 
			
		||||
                with cm():
 | 
			
		||||
                    raise IndexError
 | 
			
		||||
            except ValueError:
 | 
			
		||||
                z = 1
 | 
			
		||||
            except IndexError:
 | 
			
		||||
                z = 2
 | 
			
		||||
            assert z == 1
 | 
			
		||||
 | 
			
		||||
    def test_exception_else(self):
 | 
			
		||||
        def gn(x):
 | 
			
		||||
            return torch.cos(x)
 | 
			
		||||
@ -145,6 +177,64 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
 | 
			
		||||
        res = opt_fn(x)
 | 
			
		||||
        self.assertEqual(ref, res)
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
 | 
			
		||||
    @make_dynamo_test
 | 
			
		||||
    def test_raise_match(self):
 | 
			
		||||
        a = AttributeError
 | 
			
		||||
        b = BytesWarning
 | 
			
		||||
        c = ConnectionError
 | 
			
		||||
        d = DeprecationWarning
 | 
			
		||||
        e = Exception
 | 
			
		||||
 | 
			
		||||
        def fn(a, b):
 | 
			
		||||
            try:
 | 
			
		||||
                raise a
 | 
			
		||||
            finally:
 | 
			
		||||
                raise b
 | 
			
		||||
 | 
			
		||||
        def fix_exc_context(frame_exc, new_exc, old_exc):
 | 
			
		||||
            # slightly change from ExitStack.fix_exc_context function
 | 
			
		||||
            while 1:
 | 
			
		||||
                exc_context = new_exc.__context__
 | 
			
		||||
                if exc_context is None or exc_context is old_exc:
 | 
			
		||||
                    return
 | 
			
		||||
                if exc_context is frame_exc:
 | 
			
		||||
                    break
 | 
			
		||||
                new_exc = exc_context
 | 
			
		||||
            new_exc.__context__ = old_exc
 | 
			
		||||
 | 
			
		||||
        @contextlib.contextmanager
 | 
			
		||||
        def ctx():
 | 
			
		||||
            try:
 | 
			
		||||
                yield
 | 
			
		||||
            finally:
 | 
			
		||||
                frame_exc = prev_exc = sys.exc_info()
 | 
			
		||||
                args = [(d, c), (b, a)]
 | 
			
		||||
                for x, y in args:
 | 
			
		||||
                    try:
 | 
			
		||||
                        fn(x, y)
 | 
			
		||||
                    except BaseException:
 | 
			
		||||
                        new_exc = sys.exc_info()
 | 
			
		||||
                        fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1])
 | 
			
		||||
                        prev_exc = new_exc
 | 
			
		||||
 | 
			
		||||
                try:
 | 
			
		||||
                    fixed_ctx = prev_exc[1].__context__
 | 
			
		||||
                    raise prev_exc[1]
 | 
			
		||||
                except BaseException:
 | 
			
		||||
                    prev_exc[1].__context__ = fixed_ctx
 | 
			
		||||
                    raise
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            with ctx():
 | 
			
		||||
                raise e
 | 
			
		||||
        except Exception as exc:
 | 
			
		||||
            assert isinstance(exc, a)
 | 
			
		||||
            assert isinstance(exc.__context__, b)
 | 
			
		||||
            assert isinstance(exc.__context__.__context__, c)
 | 
			
		||||
            assert isinstance(exc.__context__.__context__.__context__, d)
 | 
			
		||||
            assert isinstance(exc.__context__.__context__.__context__.__context__, e)
 | 
			
		||||
 | 
			
		||||
    # TODO(anijain2305) - does not work with fullgraph=True
 | 
			
		||||
    def test_exception_with_another_exception2(self):
 | 
			
		||||
        def gn(x):
 | 
			
		||||
@ -455,6 +545,103 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
 | 
			
		||||
        self.assertEqual(ref[0], res[0])
 | 
			
		||||
        self.assertEqual(ref[1], res[1])
 | 
			
		||||
 | 
			
		||||
    @make_dynamo_test
 | 
			
		||||
    def test_reraise_first_exc(self):
 | 
			
		||||
        def fn():
 | 
			
		||||
            try:
 | 
			
		||||
                raise ZeroDivisionError
 | 
			
		||||
            except ZeroDivisionError:
 | 
			
		||||
                try:
 | 
			
		||||
                    raise ValueError
 | 
			
		||||
                except ValueError:
 | 
			
		||||
                    pass
 | 
			
		||||
                raise
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            fn()
 | 
			
		||||
        except ZeroDivisionError:
 | 
			
		||||
            pass
 | 
			
		||||
        assert sys.exc_info()[0] is None
 | 
			
		||||
 | 
			
		||||
    @make_dynamo_test
 | 
			
		||||
    def test_ensure_exception_is_active_after_try_except_block(self):
 | 
			
		||||
        try:
 | 
			
		||||
            try:
 | 
			
		||||
                raise ZeroDivisionError
 | 
			
		||||
            except ZeroDivisionError:
 | 
			
		||||
                for exc in (KeyError, IndexError):
 | 
			
		||||
                    try:
 | 
			
		||||
                        raise exc
 | 
			
		||||
                    except exc:
 | 
			
		||||
                        pass
 | 
			
		||||
                raise
 | 
			
		||||
        except ZeroDivisionError:
 | 
			
		||||
            pass
 | 
			
		||||
        assert sys.exc_info()[0] is None
 | 
			
		||||
 | 
			
		||||
    @make_dynamo_test
 | 
			
		||||
    def test_ensure_exception_is_active_inside_try_except_block(self):
 | 
			
		||||
        try:
 | 
			
		||||
            try:
 | 
			
		||||
                raise ZeroDivisionError
 | 
			
		||||
            except ZeroDivisionError:
 | 
			
		||||
                for exc in (KeyError, IndexError):
 | 
			
		||||
                    try:
 | 
			
		||||
                        raise exc
 | 
			
		||||
                    except exc as e:
 | 
			
		||||
                        assert isinstance(e.__context__, ZeroDivisionError)
 | 
			
		||||
                raise
 | 
			
		||||
        except ZeroDivisionError:
 | 
			
		||||
            pass
 | 
			
		||||
        assert sys.exc_info()[0] is None
 | 
			
		||||
 | 
			
		||||
    @make_dynamo_test
 | 
			
		||||
    def test_handle_all_exceptions(self):
 | 
			
		||||
        def cm():
 | 
			
		||||
            try:
 | 
			
		||||
                yield 1
 | 
			
		||||
            except ValueError:
 | 
			
		||||
                try:
 | 
			
		||||
                    raise TypeError
 | 
			
		||||
                finally:
 | 
			
		||||
                    pass
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            gen = cm()
 | 
			
		||||
            next(gen)
 | 
			
		||||
            gen.throw(ValueError)
 | 
			
		||||
        except TypeError:
 | 
			
		||||
            pass
 | 
			
		||||
        assert sys.exc_info()[0] is None
 | 
			
		||||
 | 
			
		||||
    @make_dynamo_test
 | 
			
		||||
    def test_reraise(self):
 | 
			
		||||
        try:
 | 
			
		||||
            try:
 | 
			
		||||
                raise ValueError
 | 
			
		||||
            except ValueError:  # noqa: TRY203
 | 
			
		||||
                raise
 | 
			
		||||
        except ValueError:
 | 
			
		||||
            pass
 | 
			
		||||
        assert sys.exc_info()[0] is None
 | 
			
		||||
 | 
			
		||||
    @make_dynamo_test
 | 
			
		||||
    def test_raise_finally_simple(self):
 | 
			
		||||
        def fn():
 | 
			
		||||
            try:
 | 
			
		||||
                raise ValueError
 | 
			
		||||
            except ValueError:
 | 
			
		||||
                try:
 | 
			
		||||
                    raise TypeError
 | 
			
		||||
                finally:
 | 
			
		||||
                    pass
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            fn()
 | 
			
		||||
        except TypeError:
 | 
			
		||||
            pass
 | 
			
		||||
        assert sys.exc_info()[0] is None
 | 
			
		||||
 | 
			
		||||
    def test_reconstruct___context__(self):
 | 
			
		||||
        @torch.compile(backend="eager", fullgraph=True)
 | 
			
		||||
        def fn(t):
 | 
			
		||||
@ -574,6 +761,54 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
 | 
			
		||||
        with self.assertRaisesRegex(TypeError, "exception cause must be"):
 | 
			
		||||
            fn(t, e)
 | 
			
		||||
 | 
			
		||||
    @parametrize(
 | 
			
		||||
        "ex",
 | 
			
		||||
        [TypeError, CustomException],
 | 
			
		||||
        name_fn=lambda x: x.__name__,
 | 
			
		||||
    )
 | 
			
		||||
    @make_dynamo_test
 | 
			
		||||
    def test_set___cause__(self, ex):
 | 
			
		||||
        def fn():
 | 
			
		||||
            try:
 | 
			
		||||
                raise ex
 | 
			
		||||
            except ex:
 | 
			
		||||
                raise TypeError from None
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            fn()
 | 
			
		||||
        except TypeError as e:
 | 
			
		||||
            assert isinstance(e.__context__, ex)
 | 
			
		||||
            assert e.__cause__ is None
 | 
			
		||||
            assert e.__suppress_context__ is True
 | 
			
		||||
 | 
			
		||||
    @parametrize(
 | 
			
		||||
        "ex",
 | 
			
		||||
        [RuntimeError, CustomException],
 | 
			
		||||
        name_fn=lambda x: x.__name__,
 | 
			
		||||
    )
 | 
			
		||||
    @make_dynamo_test
 | 
			
		||||
    def test_set___cause___error(self, ex):
 | 
			
		||||
        def fn():
 | 
			
		||||
            try:
 | 
			
		||||
                raise ex
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                e.__cause__ = 2
 | 
			
		||||
                raise
 | 
			
		||||
 | 
			
		||||
        z = 0
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            fn()
 | 
			
		||||
        except TypeError as e:
 | 
			
		||||
            z = 1
 | 
			
		||||
            assert e.args == (
 | 
			
		||||
                "exception cause must be None or derive from BaseException",
 | 
			
		||||
            )
 | 
			
		||||
        except Exception:
 | 
			
		||||
            raise AssertionError from None
 | 
			
		||||
 | 
			
		||||
        assert z == 1
 | 
			
		||||
 | 
			
		||||
    def test_user_defined_exception_variable(self):
 | 
			
		||||
        @torch.compile(backend="eager", fullgraph=True)
 | 
			
		||||
        def fn(t):
 | 
			
		||||
@ -852,6 +1087,9 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
 | 
			
		||||
        self.assertIs(a.__context__, c)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
instantiate_parametrized_tests(ExceptionTests)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    from torch._dynamo.test_case import run_tests
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,7 @@ from collections import OrderedDict
 | 
			
		||||
import torch
 | 
			
		||||
import torch._dynamo.test_case
 | 
			
		||||
import torch._dynamo.testing
 | 
			
		||||
from torch._dynamo.exc import InternalTorchDynamoError, Unsupported
 | 
			
		||||
from torch._dynamo.exc import Unsupported
 | 
			
		||||
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
 | 
			
		||||
from torch._dynamo.utils import counters
 | 
			
		||||
from torch.testing._internal.common_utils import (
 | 
			
		||||
@ -1239,7 +1239,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
 | 
			
		||||
        y = self._compile_check(fn, (t,))
 | 
			
		||||
        self.assertEqual(y, t.sin() + t.cos())
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
 | 
			
		||||
    def test_throw_with_finally(self):
 | 
			
		||||
        z = 0
 | 
			
		||||
 | 
			
		||||
@ -1296,24 +1295,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
 | 
			
		||||
        self.assertEqual(y, t.sin() + t.cos())
 | 
			
		||||
        self.assertEqual(z, 101)
 | 
			
		||||
 | 
			
		||||
    def test_throw_three_arguments(self):
 | 
			
		||||
        def whoo(t):
 | 
			
		||||
            try:
 | 
			
		||||
                yield t.sin()
 | 
			
		||||
            except ValueError:
 | 
			
		||||
                yield t.cos()
 | 
			
		||||
 | 
			
		||||
        @torch.compile(backend="eager", fullgraph=True)
 | 
			
		||||
        def fn(t):
 | 
			
		||||
            gen = whoo(t)
 | 
			
		||||
            a = next(gen)
 | 
			
		||||
            b = gen.throw(ValueError, "Error", None)
 | 
			
		||||
            return a + b
 | 
			
		||||
 | 
			
		||||
        t = torch.randn(2)
 | 
			
		||||
        with self.assertRaises(InternalTorchDynamoError):
 | 
			
		||||
            fn(t)
 | 
			
		||||
 | 
			
		||||
    def test_throw_no_yield_after_throw(self):
 | 
			
		||||
        z = 0
 | 
			
		||||
 | 
			
		||||
@ -1420,7 +1401,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
 | 
			
		||||
        with self.assertRaises(Unsupported):
 | 
			
		||||
            fn(t)
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
 | 
			
		||||
    def test_throw_try_except_finally(self):
 | 
			
		||||
        z = 0
 | 
			
		||||
 | 
			
		||||
@ -1638,7 +1618,6 @@ class GeneratorThrowCpythonTests(GeneratorTestsBase):
 | 
			
		||||
 | 
			
		||||
        self._compile_check(fn)
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE")
 | 
			
		||||
    def test_exception_context_with_yield_inside_generator(self):
 | 
			
		||||
        # Check that the context is also available from inside the generator
 | 
			
		||||
        # with yield, as opposed to outside.
 | 
			
		||||
 | 
			
		||||
@ -118,6 +118,7 @@ from .replay_record import ExecutionRecord
 | 
			
		||||
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
 | 
			
		||||
from .symbolic_convert import (
 | 
			
		||||
    DistributedState,
 | 
			
		||||
    ExceptionStack,
 | 
			
		||||
    InstructionTranslator,
 | 
			
		||||
    LocalState,
 | 
			
		||||
    SpeculationLog,
 | 
			
		||||
@ -689,6 +690,7 @@ def _compile(
 | 
			
		||||
        nonlocal output
 | 
			
		||||
        nonlocal tracer
 | 
			
		||||
        speculation_log.restart()
 | 
			
		||||
        exn_vt_stack = ExceptionStack()
 | 
			
		||||
        tracer = InstructionTranslator(
 | 
			
		||||
            instructions,
 | 
			
		||||
            code,
 | 
			
		||||
@ -704,6 +706,7 @@ def _compile(
 | 
			
		||||
            export_constraints,
 | 
			
		||||
            frame_state=frame_state,
 | 
			
		||||
            speculation_log=speculation_log,
 | 
			
		||||
            exn_vt_stack=exn_vt_stack,
 | 
			
		||||
            distributed_state=distributed_state,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -362,7 +362,7 @@ def raise_observed_exception(
 | 
			
		||||
    # CPython here raises an exception. Since there is no python code, we have to manually setup the exception
 | 
			
		||||
    # stack and raise the exception.
 | 
			
		||||
    exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {})  # type: ignore[arg-type]
 | 
			
		||||
    tx.exn_vt_stack.append(exception_vt)
 | 
			
		||||
    tx.exn_vt_stack.set_current_exception(exception_vt)
 | 
			
		||||
    raise observed_exception_map[exc_type]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -391,7 +391,7 @@ def handle_observed_exception(tx: Any) -> None:
 | 
			
		||||
    #
 | 
			
		||||
 | 
			
		||||
    # Fortunately this translates to a simple pop from the exn_vt_stack
 | 
			
		||||
    tx.exn_vt_stack.pop()
 | 
			
		||||
    tx.exn_vt_stack.clear_current_exception()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# These exceptions are ok to fallback to eager/graph_break.
 | 
			
		||||
 | 
			
		||||
@ -44,7 +44,7 @@ import traceback
 | 
			
		||||
import types
 | 
			
		||||
import typing
 | 
			
		||||
import weakref
 | 
			
		||||
from typing import Any, Callable, cast, Optional, Union
 | 
			
		||||
from typing import Any, Callable, cast, NoReturn, Optional, Union
 | 
			
		||||
from unittest.mock import patch
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
@ -383,9 +383,11 @@ class BlockStackEntry:
 | 
			
		||||
            and hasattr(self.with_context, "target_values")
 | 
			
		||||
            and self.with_context.target_values
 | 
			
		||||
        ):
 | 
			
		||||
            return ReenterWith(self.stack_index, tuple(self.with_context.target_values))
 | 
			
		||||
            return ReenterWith(
 | 
			
		||||
                self.stack_index - 1, tuple(self.with_context.target_values)
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            return ReenterWith(self.stack_index)
 | 
			
		||||
            return ReenterWith(self.stack_index - 1)
 | 
			
		||||
 | 
			
		||||
    def exit(self, tx, is_graph_break):
 | 
			
		||||
        assert self.with_context is not None
 | 
			
		||||
@ -956,6 +958,100 @@ class BytecodeDistpatchTableMeta(type):
 | 
			
		||||
        cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class ExceptionStack:
 | 
			
		||||
    """
 | 
			
		||||
    Exception stack that it is shared among all InstructionTranslator instances
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # Exception handling in CPython is a bit confusing and some of the bytecode
 | 
			
		||||
    # have a slightly different behavior than what is is documented. While reading
 | 
			
		||||
    # the documentation, is important to notice that the terms "current exception"
 | 
			
		||||
    # and "stack" sometimes refers to a C variable with the same name and the
 | 
			
		||||
    # exception stack, respectively.
 | 
			
		||||
    #
 | 
			
		||||
    # The lifetime of an exception is (Python 3.11+):
 | 
			
		||||
    #  + tx._raise_exception_variable(...) := sets the current_exception variable
 | 
			
		||||
    #  + PUSH_EXC_INFO := pushes the current_exception to the *exception stack*
 | 
			
		||||
    #  + POP_EXCEPT := pops TOS from the *exception stack*
 | 
			
		||||
 | 
			
		||||
    _exc_stack: list[VariableTracker] = dataclasses.field(default_factory=list)
 | 
			
		||||
    _current_exception: Optional[VariableTracker] = dataclasses.field(default=None)
 | 
			
		||||
 | 
			
		||||
    def clear_current_exception(self):
 | 
			
		||||
        self._current_exception = None
 | 
			
		||||
 | 
			
		||||
    def set_current_exception(self, val):
 | 
			
		||||
        self._set_context_and_break_context_reference_cycle(val)
 | 
			
		||||
        self._current_exception = val
 | 
			
		||||
 | 
			
		||||
    def move_current_exception_to_stack(self):
 | 
			
		||||
        assert self._current_exception is not None
 | 
			
		||||
        self.append(self._current_exception)
 | 
			
		||||
        self.clear_current_exception()
 | 
			
		||||
 | 
			
		||||
    def get_current_exception(self):
 | 
			
		||||
        assert self._current_exception is not None
 | 
			
		||||
        return self._current_exception
 | 
			
		||||
 | 
			
		||||
    def _set_context_recursive(self, val, prev_idx):
 | 
			
		||||
        if (ctx := val.__context__) and type(ctx) is not ConstantVariable:
 | 
			
		||||
            return val
 | 
			
		||||
        if len(self._exc_stack) + prev_idx > 0:
 | 
			
		||||
            prev = self._exc_stack[prev_idx]
 | 
			
		||||
            self._set_context_recursive(prev, prev_idx - 1)
 | 
			
		||||
            val.set_context(prev)
 | 
			
		||||
        return val
 | 
			
		||||
 | 
			
		||||
    def _break_context_reference_cycle(self, val):
 | 
			
		||||
        # See test_exceptions::test_raise_does_not_create_context_chain_cycle
 | 
			
		||||
        # Based on https://github.com/python/cpython/blob/e635bf2e49797ecb976ce45a67fce2201a25ca68/Python/errors.c#L207-L228
 | 
			
		||||
        # As noted on CPython, this is O(chain length) but the context chains
 | 
			
		||||
        # are usually very small
 | 
			
		||||
        o = slow_o = val
 | 
			
		||||
        slow_update_toggle = False  # floyd's algorithm for detecting cycle
 | 
			
		||||
        while True:
 | 
			
		||||
            context = o.__context__
 | 
			
		||||
            if type(context) is ConstantVariable:  # context not set
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
            if context is val:
 | 
			
		||||
                o.set_context(ConstantVariable(None))
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
            o = context
 | 
			
		||||
            if o is slow_o:
 | 
			
		||||
                # pre-existing cycle - all exceptions on the path were
 | 
			
		||||
                # visited and checked
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
            if slow_update_toggle:
 | 
			
		||||
                slow_o = slow_o.__context__  # visited all exceptions
 | 
			
		||||
            slow_update_toggle = not slow_update_toggle
 | 
			
		||||
 | 
			
		||||
    def _set_context_and_break_context_reference_cycle(self, val):
 | 
			
		||||
        # set Exception.__context__
 | 
			
		||||
        self._set_context_recursive(val, len(self._exc_stack) - 1)
 | 
			
		||||
        self._break_context_reference_cycle(val)
 | 
			
		||||
 | 
			
		||||
    def pop(self):
 | 
			
		||||
        return self._exc_stack.pop()
 | 
			
		||||
 | 
			
		||||
    def append(self, val):
 | 
			
		||||
        self._exc_stack.append(val)
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return len(self._exc_stack)
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index):
 | 
			
		||||
        return self._exc_stack[index]
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"{self._exc_stack=} - {self._current_exception=}"
 | 
			
		||||
 | 
			
		||||
    __repr__ = __str__
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InstructionTranslatorBase(
 | 
			
		||||
    metaclass=BytecodeDistpatchTableMeta,
 | 
			
		||||
):
 | 
			
		||||
@ -975,7 +1071,7 @@ class InstructionTranslatorBase(
 | 
			
		||||
    inconsistent_side_effects: bool
 | 
			
		||||
    current_speculation: Optional[SpeculationEntry]
 | 
			
		||||
    dispatch_table: list[Any]
 | 
			
		||||
    exn_vt_stack: list[VariableTracker]
 | 
			
		||||
    exn_vt_stack: ExceptionStack
 | 
			
		||||
    exec_recorder: Optional[ExecutionRecorder]
 | 
			
		||||
    strict_checks_fn: Optional[Callable[[VariableTracker], bool]]
 | 
			
		||||
    start_point: Optional[int]
 | 
			
		||||
@ -1655,43 +1751,7 @@ class InstructionTranslatorBase(
 | 
			
		||||
                self.push(ConstantVariable.create(None))
 | 
			
		||||
            self.jump(inst)
 | 
			
		||||
 | 
			
		||||
    def _raise_exception_variable(self, inst):
 | 
			
		||||
        def set_context_recursive(val, prev_idx):
 | 
			
		||||
            if (ctx := val.__context__) and type(ctx) is not ConstantVariable:
 | 
			
		||||
                return val
 | 
			
		||||
            if len(self.exn_vt_stack) + prev_idx > 0:
 | 
			
		||||
                prev = self.exn_vt_stack[prev_idx]
 | 
			
		||||
                set_context_recursive(prev, prev_idx - 1)
 | 
			
		||||
                val.set_context(prev)
 | 
			
		||||
            return val
 | 
			
		||||
 | 
			
		||||
        def break_context_reference_cycle(val):
 | 
			
		||||
            # See test_exceptions::test_raise_does_not_create_context_chain_cycle
 | 
			
		||||
            # Based on https://github.com/python/cpython/blob/e635bf2e49797ecb976ce45a67fce2201a25ca68/Python/errors.c#L207-L228
 | 
			
		||||
            # As noted on CPython, this is O(chain length) but the context chains
 | 
			
		||||
            # are usually very small
 | 
			
		||||
            o = slow_o = val
 | 
			
		||||
            slow_update_toggle = False  # floyd's algorithm for detecting cycle
 | 
			
		||||
            while True:
 | 
			
		||||
                context = o.__context__
 | 
			
		||||
                if type(context) is ConstantVariable:  # context not set
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                if context is val:
 | 
			
		||||
                    o.set_context(ConstantVariable(None))
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                o = context
 | 
			
		||||
                if o is slow_o:
 | 
			
		||||
                    # pre-existing cycle - all exceptions on the path were
 | 
			
		||||
                    # visited and checked
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                if slow_update_toggle:
 | 
			
		||||
                    slow_o = slow_o.__context__  # visited all exceptions
 | 
			
		||||
                slow_update_toggle = not slow_update_toggle
 | 
			
		||||
 | 
			
		||||
        val = self.pop()
 | 
			
		||||
    def _raise_exception_variable(self, val) -> NoReturn:
 | 
			
		||||
        # User can raise exception in 2 ways
 | 
			
		||||
        #   1) raise exception type - raise NotImplementedError
 | 
			
		||||
        #   2) raise execption instance - raise NotImplemetedError("foo")
 | 
			
		||||
@ -1705,6 +1765,7 @@ class InstructionTranslatorBase(
 | 
			
		||||
            val = val.call_function(self, [], {})  # type: ignore[arg-type]
 | 
			
		||||
 | 
			
		||||
        # Handle https://peps.python.org/pep-0479/
 | 
			
		||||
        # CPython 3.12+ has a specific bytecode instruction (CALL_INTRINSIC_1 3) for this
 | 
			
		||||
        if (
 | 
			
		||||
            is_generator(self.f_code)
 | 
			
		||||
            and isinstance(val, variables.ExceptionVariable)
 | 
			
		||||
@ -1712,12 +1773,8 @@ class InstructionTranslatorBase(
 | 
			
		||||
        ):
 | 
			
		||||
            val = variables.BuiltinVariable(RuntimeError).call_function(self, [], {})  # type: ignore[arg-type]
 | 
			
		||||
 | 
			
		||||
        # set Exception.__context__
 | 
			
		||||
        set_context_recursive(val, len(self.exn_vt_stack) - 1)
 | 
			
		||||
        break_context_reference_cycle(val)
 | 
			
		||||
 | 
			
		||||
        # Save the exception in a global data structure
 | 
			
		||||
        self.exn_vt_stack.append(val)
 | 
			
		||||
        self.exn_vt_stack.set_current_exception(val)
 | 
			
		||||
 | 
			
		||||
        # 2) when user raises exception instance
 | 
			
		||||
        if self._isinstance_exception(val):
 | 
			
		||||
@ -1732,26 +1789,29 @@ class InstructionTranslatorBase(
 | 
			
		||||
 | 
			
		||||
    def RAISE_VARARGS(self, inst):
 | 
			
		||||
        if inst.arg == 0:
 | 
			
		||||
            # duplicate the top of the stack and re-raise it
 | 
			
		||||
            if sys.version_info < (3, 11):
 | 
			
		||||
                unimplemented_v2(
 | 
			
		||||
                    gb_type="Re-raise with no arguments",
 | 
			
		||||
                    context="",
 | 
			
		||||
                    explanation="Dynamo does not support re-raising the previous exception "
 | 
			
		||||
                    "in Python < 3.11 (empty `raise`)",
 | 
			
		||||
                    hints=[],
 | 
			
		||||
                )
 | 
			
		||||
            assert self._isinstance_exception(self.stack[-1])
 | 
			
		||||
            self.stack.append(self.stack[-1])
 | 
			
		||||
            self._raise_exception_variable(inst)
 | 
			
		||||
            # re-raise the previous exception. Here CPython refers to the exception
 | 
			
		||||
            # on top of the exception stack
 | 
			
		||||
            assert len(self.exn_vt_stack)
 | 
			
		||||
            val = self.exn_vt_stack[-1]
 | 
			
		||||
            assert self._isinstance_exception(val), val
 | 
			
		||||
            self._raise_exception_variable(val)
 | 
			
		||||
        elif inst.arg == 1:
 | 
			
		||||
            self._raise_exception_variable(inst)
 | 
			
		||||
            # raise TOS
 | 
			
		||||
            val = self.stack[-1]
 | 
			
		||||
            self._raise_exception_variable(val)
 | 
			
		||||
        else:
 | 
			
		||||
            # Support raise .. from None ... Dynamo does not track __cause__ and other attributes of exception. So we
 | 
			
		||||
            # ignore `from None` part.
 | 
			
		||||
            # raise .. from None
 | 
			
		||||
            from_vt = self.pop()
 | 
			
		||||
            if isinstance(from_vt, ConstantVariable) and from_vt.value is None:
 | 
			
		||||
                self._raise_exception_variable(inst)
 | 
			
		||||
                val = self.pop()
 | 
			
		||||
                try:
 | 
			
		||||
                    self._raise_exception_variable(val)
 | 
			
		||||
                finally:
 | 
			
		||||
                    # Update __cause__/__supppress_context__ in the raised exception
 | 
			
		||||
                    curr_exc = self.exn_vt_stack.get_current_exception()
 | 
			
		||||
                    curr_exc.call_setattr(
 | 
			
		||||
                        self, ConstantVariable("__cause__"), ConstantVariable(None)
 | 
			
		||||
                    )
 | 
			
		||||
            unimplemented_v2(
 | 
			
		||||
                gb_type="Re-raise with 2 arguments",
 | 
			
		||||
                context=str(from_vt),
 | 
			
		||||
@ -1774,16 +1834,27 @@ class InstructionTranslatorBase(
 | 
			
		||||
            self.RERAISE(inst)
 | 
			
		||||
 | 
			
		||||
    def RERAISE(self, inst):
 | 
			
		||||
        # https://docs.python.org/3/library/dis.html#opcode-RERAISE
 | 
			
		||||
        #   Re-raises the exception currently on top of the stack. If oparg is
 | 
			
		||||
        #   non-zero, pops an additional value from the stack which is used to
 | 
			
		||||
        #   set f_lasti of the current frame.
 | 
			
		||||
 | 
			
		||||
        if sys.version_info >= (3, 11):
 | 
			
		||||
            # RERAISE is currently supported in a narrow case of `raise ... from None`
 | 
			
		||||
            self._raise_exception_variable(inst)
 | 
			
		||||
        unimplemented_v2(
 | 
			
		||||
            gb_type="RERAISE in Python < 3.11",
 | 
			
		||||
            context="",
 | 
			
		||||
            explanation="RERAISE bytecode (https://docs.python.org/3.10/library/dis.html#opcode-RERAISE) "
 | 
			
		||||
            "not supported in Python < 3.11. This bytecode is generated by try/with blocks.",
 | 
			
		||||
            hints=[],
 | 
			
		||||
        )
 | 
			
		||||
            val = self.pop()
 | 
			
		||||
            if inst.argval:
 | 
			
		||||
                # RERAISE 1
 | 
			
		||||
                _ = self.pop()
 | 
			
		||||
                self._raise_exception_variable(val)
 | 
			
		||||
            else:
 | 
			
		||||
                # RERAISE 0
 | 
			
		||||
                self.push(val)
 | 
			
		||||
                self._raise_exception_variable(val)
 | 
			
		||||
        else:
 | 
			
		||||
            _exc = self.pop()
 | 
			
		||||
            val = self.pop()
 | 
			
		||||
            _tb = self.pop()
 | 
			
		||||
            self._raise_exception_variable(val)
 | 
			
		||||
 | 
			
		||||
    def _isinstance_exception(self, val):
 | 
			
		||||
        return isinstance(
 | 
			
		||||
@ -1813,7 +1884,7 @@ class InstructionTranslatorBase(
 | 
			
		||||
        else:
 | 
			
		||||
            assert len(self.stack) >= 7
 | 
			
		||||
            fn = self.stack[-7]
 | 
			
		||||
            val = self.stack[-4]
 | 
			
		||||
            val = self.stack[-2]
 | 
			
		||||
            assert self._isinstance_exception(val)
 | 
			
		||||
            typ = BuiltinVariable(val.exc_type)  # type: ignore[attr-defined]
 | 
			
		||||
            tb = ConstantVariable(None)
 | 
			
		||||
@ -1843,8 +1914,7 @@ class InstructionTranslatorBase(
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                # 3) push the exception to the stack
 | 
			
		||||
                assert len(self.exn_vt_stack)
 | 
			
		||||
                self.push(self.exn_vt_stack[-1])
 | 
			
		||||
                self.push(self.exn_vt_stack.get_current_exception())
 | 
			
		||||
 | 
			
		||||
                # 4) jump to the handler
 | 
			
		||||
                self.jump(exn_tab_entry)
 | 
			
		||||
@ -1867,15 +1937,13 @@ class InstructionTranslatorBase(
 | 
			
		||||
            if len(self.block_stack):
 | 
			
		||||
                # base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455
 | 
			
		||||
 | 
			
		||||
                assert len(self.exn_vt_stack)
 | 
			
		||||
                exception_var = self.exn_vt_stack[-1]
 | 
			
		||||
 | 
			
		||||
                block_stack_entry = self.block_stack.pop()
 | 
			
		||||
 | 
			
		||||
                while block_stack_entry.inst.opname == "EXCEPT_HANDLER":
 | 
			
		||||
                    # TODO(anijain2305) - This is not tested .. unable to create a testcase
 | 
			
		||||
                    # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
 | 
			
		||||
                    self.popn(3)
 | 
			
		||||
                    self.exn_vt_stack.pop()
 | 
			
		||||
                    if len(self.block_stack) == 0:
 | 
			
		||||
                        # No handler found in this frame. Bubble the exception to the parent
 | 
			
		||||
                        # instruction translater.
 | 
			
		||||
@ -1892,15 +1960,8 @@ class InstructionTranslatorBase(
 | 
			
		||||
                        raise raised_exception
 | 
			
		||||
                    block_stack_entry = self.block_stack.pop()
 | 
			
		||||
 | 
			
		||||
                if block_stack_entry.inst.opname != "SETUP_FINALLY":
 | 
			
		||||
                    unimplemented_v2(
 | 
			
		||||
                        gb_type="Exception raised with invalid exception handler",
 | 
			
		||||
                        context="",
 | 
			
		||||
                        explanation="Exception raised when top of the block stack "
 | 
			
		||||
                        "is not exception handler (e.g. try .. with .. except). "
 | 
			
		||||
                        f"Current TOS is {block_stack_entry.inst}",
 | 
			
		||||
                        hints=[],
 | 
			
		||||
                    )
 | 
			
		||||
                exception_var = self.exn_vt_stack.get_current_exception()
 | 
			
		||||
                self.exn_vt_stack.move_current_exception_to_stack()
 | 
			
		||||
 | 
			
		||||
                # 1) pop values from the stack until it matches the stack depth
 | 
			
		||||
                # for the handler
 | 
			
		||||
@ -1954,15 +2015,37 @@ class InstructionTranslatorBase(
 | 
			
		||||
                raise raised_exception
 | 
			
		||||
 | 
			
		||||
    def PUSH_EXC_INFO(self, inst):
 | 
			
		||||
        # https://docs.python.org/3/library/dis.html#opcode-PUSH_EXC_INFO
 | 
			
		||||
        #   Pops a value from the stack. Pushes the current exception to the top
 | 
			
		||||
        #   of the stack. Pushes the value originally popped back to the stack.
 | 
			
		||||
        #
 | 
			
		||||
        # The behavior of this opcode in CPython is a bit different than what it
 | 
			
		||||
        # is described. It pops a value from the stack, pushes the top of the
 | 
			
		||||
        # exception stack to the interpreter stack and moves the
 | 
			
		||||
        # "current exception" to the exception stack.
 | 
			
		||||
        #
 | 
			
		||||
        # As an example, suppose the stack is in the following state:
 | 
			
		||||
        #   + stack = [..., ConstantVariable(1), ConstantVariable(2)]
 | 
			
		||||
        #   + current_exception = TypeError
 | 
			
		||||
        #   + exception_stack = [ValueError]
 | 
			
		||||
        #
 | 
			
		||||
        # After PUSH_EXC_INFO is executed
 | 
			
		||||
        #   + stack = [..., ConstantVariable(1), ValueError, ConstantVariable(2)]
 | 
			
		||||
        #   + current_exception = None
 | 
			
		||||
        #   + exception_stack = [ValueError, TypeError]
 | 
			
		||||
 | 
			
		||||
        val = self.pop()
 | 
			
		||||
        assert len(self.exn_vt_stack)
 | 
			
		||||
        self.push(self.exn_vt_stack[-1])
 | 
			
		||||
        if len(self.exn_vt_stack) == 0:
 | 
			
		||||
            prev_exc = ConstantVariable(None)
 | 
			
		||||
        else:
 | 
			
		||||
            prev_exc = self.exn_vt_stack[-1]
 | 
			
		||||
        self.push(prev_exc)
 | 
			
		||||
        self.push(val)
 | 
			
		||||
        self.exn_vt_stack.move_current_exception_to_stack()
 | 
			
		||||
 | 
			
		||||
    def POP_EXCEPT(self, inst):
 | 
			
		||||
        if sys.version_info >= (3, 11):
 | 
			
		||||
            val = self.pop()
 | 
			
		||||
            assert self._isinstance_exception(val)
 | 
			
		||||
            _ = self.pop()
 | 
			
		||||
            # This exception is handled and therefore we can clear the error indicator
 | 
			
		||||
            assert len(self.exn_vt_stack)
 | 
			
		||||
            self.exn_vt_stack.pop()
 | 
			
		||||
@ -2916,6 +2999,8 @@ class InstructionTranslatorBase(
 | 
			
		||||
        else:
 | 
			
		||||
            target = inst.target
 | 
			
		||||
 | 
			
		||||
        self.push(exit)
 | 
			
		||||
 | 
			
		||||
        if target:
 | 
			
		||||
            if isinstance(self, InstructionTranslator):
 | 
			
		||||
                self.block_stack.append(
 | 
			
		||||
@ -2924,7 +3009,6 @@ class InstructionTranslatorBase(
 | 
			
		||||
            else:
 | 
			
		||||
                self.block_stack.append(BlockStackEntry(inst, target, len(self.stack)))
 | 
			
		||||
 | 
			
		||||
        self.push(exit)
 | 
			
		||||
        self.push(ctx.enter(self))
 | 
			
		||||
 | 
			
		||||
    def append_prefix_inst(self, inst):
 | 
			
		||||
@ -3125,6 +3209,7 @@ class InstructionTranslatorBase(
 | 
			
		||||
        export: bool,
 | 
			
		||||
        inline_depth: int,
 | 
			
		||||
        speculation_log: SpeculationLog,
 | 
			
		||||
        exn_vt_stack: ExceptionStack,
 | 
			
		||||
        distributed_state: Optional[DistributedState],
 | 
			
		||||
        # This determines whether to use the execution recorder.
 | 
			
		||||
        closure: Optional[tuple[types.CellType]] = None,
 | 
			
		||||
@ -3149,7 +3234,7 @@ class InstructionTranslatorBase(
 | 
			
		||||
        self.kw_names = None
 | 
			
		||||
        self.accept_prefix_inst = True
 | 
			
		||||
        self.prefix_insts = []
 | 
			
		||||
        self.exn_vt_stack = []
 | 
			
		||||
        self.exn_vt_stack = exn_vt_stack
 | 
			
		||||
 | 
			
		||||
        # Properties of the input/output code
 | 
			
		||||
        self.instructions: list[Instruction] = instructions
 | 
			
		||||
@ -3233,6 +3318,7 @@ class InstructionTranslator(InstructionTranslatorBase):
 | 
			
		||||
        export_constraints,
 | 
			
		||||
        frame_state,
 | 
			
		||||
        speculation_log: SpeculationLog,
 | 
			
		||||
        exn_vt_stack: ExceptionStack,
 | 
			
		||||
        distributed_state: Optional[DistributedState],
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        _step_logger()(
 | 
			
		||||
@ -3266,6 +3352,7 @@ class InstructionTranslator(InstructionTranslatorBase):
 | 
			
		||||
            export=export,
 | 
			
		||||
            inline_depth=0,
 | 
			
		||||
            speculation_log=speculation_log,
 | 
			
		||||
            exn_vt_stack=exn_vt_stack,
 | 
			
		||||
            distributed_state=distributed_state,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -3818,9 +3905,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
 | 
			
		||||
                self.run()
 | 
			
		||||
        except exc.ObservedException as e:
 | 
			
		||||
            msg = f"Observed exception DURING INLING {code} : {e}"
 | 
			
		||||
            # TODO(anijain2305) - This works but we should probably have a
 | 
			
		||||
            # global/central data structure for the exception stack.
 | 
			
		||||
            parent.exn_vt_stack.extend(self.exn_vt_stack)
 | 
			
		||||
            log.debug(msg)
 | 
			
		||||
            # bubble up the exception to the parent frame.
 | 
			
		||||
            raise
 | 
			
		||||
@ -3895,6 +3979,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
 | 
			
		||||
            export=parent.export,
 | 
			
		||||
            inline_depth=parent.inline_depth + 1,
 | 
			
		||||
            speculation_log=parent.speculation_log,
 | 
			
		||||
            exn_vt_stack=parent.exn_vt_stack,
 | 
			
		||||
            distributed_state=parent.distributed_state,
 | 
			
		||||
        )
 | 
			
		||||
        self.funcvar = funcvar
 | 
			
		||||
 | 
			
		||||
@ -41,7 +41,6 @@ from ..bytecode_transformation import create_call_function, create_rot_n, is_gen
 | 
			
		||||
from ..exc import (
 | 
			
		||||
    get_dynamo_observed_exception,
 | 
			
		||||
    handle_observed_exception,
 | 
			
		||||
    IncorrectUsage,
 | 
			
		||||
    InfiniteGeneratorError,
 | 
			
		||||
    ObservedException,
 | 
			
		||||
    ObservedGeneratorExit,
 | 
			
		||||
@ -521,7 +520,6 @@ class LocalGeneratorObjectVariable(VariableTracker):
 | 
			
		||||
            with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
 | 
			
		||||
                return tracer.inline_call_()
 | 
			
		||||
        except ObservedException as e:
 | 
			
		||||
            tx.exn_vt_stack.extend(tracer.exn_vt_stack)
 | 
			
		||||
            raise e
 | 
			
		||||
        except InfiniteGeneratorError:
 | 
			
		||||
            # test/dynamo/test_misc.py::test_iterator_limit
 | 
			
		||||
@ -550,9 +548,8 @@ class LocalGeneratorObjectVariable(VariableTracker):
 | 
			
		||||
 | 
			
		||||
    def _setup_exception(self, tx, exc):
 | 
			
		||||
        tracer = self._get_inline_tracer(tx)
 | 
			
		||||
        tracer.push(exc)
 | 
			
		||||
        try:
 | 
			
		||||
            tracer._raise_exception_variable(None)
 | 
			
		||||
            tracer._raise_exception_variable(exc)
 | 
			
		||||
        except ObservedException as e:
 | 
			
		||||
            # if no handler is available (i.e. user code doesn't catch it), the
 | 
			
		||||
            # exception is raised again.
 | 
			
		||||
@ -664,19 +661,16 @@ class LocalGeneratorObjectVariable(VariableTracker):
 | 
			
		||||
            # * If the generator function does not catch the passed-in exception,
 | 
			
		||||
            # or raises a different exception, then that exception propagates to the caller.
 | 
			
		||||
 | 
			
		||||
            if len(args) > 1:
 | 
			
		||||
                raise IncorrectUsage(
 | 
			
		||||
                    "the (type, exc, tb) signature of throw() is deprecated, "
 | 
			
		||||
                    "use the single-arg signature instead."
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            # Setup the exception table and jump target in case of try...finally
 | 
			
		||||
            tracer = self._get_inline_tracer(tx)
 | 
			
		||||
            try:
 | 
			
		||||
                self._setup_exception(tx, args[0])
 | 
			
		||||
            except ObservedException:
 | 
			
		||||
                # In Python 3.9, the exception is represented as a triple (typ, val, tb)
 | 
			
		||||
                # In such cases, we re-raise the exception object given to avoid
 | 
			
		||||
                # creating a new object, so that IS_OP works.
 | 
			
		||||
                # See: https://github.com/pytorch/pytorch/pull/146496
 | 
			
		||||
                self._setup_exception(tx, args[1] if len(args) == 3 else args[0])
 | 
			
		||||
            except ObservedException:  # noqa: TRY203
 | 
			
		||||
                # propagate the exception back to the parent caller
 | 
			
		||||
                tx.exn_vt_stack.extend(tracer.exn_vt_stack)
 | 
			
		||||
                raise
 | 
			
		||||
 | 
			
		||||
            retval = self.next_variable(tx)
 | 
			
		||||
@ -749,9 +743,6 @@ class LocalGeneratorObjectVariable(VariableTracker):
 | 
			
		||||
            except get_dynamo_observed_exception(exc_type):
 | 
			
		||||
                # We should get back the exception raised before.
 | 
			
		||||
                pass
 | 
			
		||||
            except ObservedException:
 | 
			
		||||
                # Propagate anything else back to the parent caller
 | 
			
		||||
                tx.exn_vt_stack.extend(tracer.exn_vt_stack)
 | 
			
		||||
            else:
 | 
			
		||||
                raise_observed_exception(RuntimeError, tracer)
 | 
			
		||||
            return retval
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user