diff --git a/test/dynamo/cpython/3_13/test_generators.diff b/test/dynamo/cpython/3_13/test_generators.diff index 338d51894fb3..8d7c0bfd2102 100644 --- a/test/dynamo/cpython/3_13/test_generators.diff +++ b/test/dynamo/cpython/3_13/test_generators.diff @@ -1,5 +1,5 @@ diff --git a/test/dynamo/cpython/3_13/test_generators.py b/test/dynamo/cpython/3_13/test_generators.py -index e48d79d34f4..a48da0914b9 100644 +index 515fe7407f1..a48da0914b9 100644 --- a/test/dynamo/cpython/3_13/test_generators.py +++ b/test/dynamo/cpython/3_13/test_generators.py @@ -1,3 +1,56 @@ @@ -105,7 +105,8 @@ index e48d79d34f4..a48da0914b9 100644 + return self.val + + # No __iter__ method -+ + +-class ModifyUnderlyingIterableTest(unittest.TestCase): + class C: + + def __iter__(self): @@ -113,8 +114,7 @@ index e48d79d34f4..a48da0914b9 100644 + + self.assertEqual([1,2], list(i for i in C())) + - --class ModifyUnderlyingIterableTest(unittest.TestCase): ++ +class ModifyUnderlyingIterableTest(__TestCase): iterables = [ range(0), @@ -137,99 +137,16 @@ index e48d79d34f4..a48da0914b9 100644 def test_close_no_return_value(self): def f(): -@@ -630,90 +706,7 @@ class GeneratorCloseTest(unittest.TestCase): +@@ -630,7 +706,7 @@ class GeneratorCloseTest(unittest.TestCase): self.assertIsNone(f_wr()) --# See https://github.com/python/cpython/issues/125723 --class GeneratorDeallocTest(unittest.TestCase): -- def test_frame_outlives_generator(self): -- def g1(): -- a = 42 -- yield sys._getframe() -- -- def g2(): -- a = 42 -- yield -- -- def g3(obj): -- a = 42 -- obj.frame = sys._getframe() -- yield -- -- class ObjectWithFrame(): -- def __init__(self): -- self.frame = None -- -- def get_frame(index): -- if index == 1: -- return next(g1()) -- elif index == 2: -- gen = g2() -- next(gen) -- return gen.gi_frame -- elif index == 3: -- obj = ObjectWithFrame() -- next(g3(obj)) -- return obj.frame -- else: -- return None -- -- for index in (1, 2, 3): -- with self.subTest(index=index): -- frame = get_frame(index) -- frame_locals = frame.f_locals -- self.assertIn('a', frame_locals) -- self.assertEqual(frame_locals['a'], 42) -- -- def test_frame_locals_outlive_generator(self): -- frame_locals1 = None -- -- def g1(): -- nonlocal frame_locals1 -- frame_locals1 = sys._getframe().f_locals -- a = 42 -- yield -- -- def g2(): -- a = 42 -- yield sys._getframe().f_locals -- -- def get_frame_locals(index): -- if index == 1: -- nonlocal frame_locals1 -- next(g1()) -- return frame_locals1 -- if index == 2: -- return next(g2()) -- else: -- return None -- -- for index in (1, 2): -- with self.subTest(index=index): -- frame_locals = get_frame_locals(index) -- self.assertIn('a', frame_locals) -- self.assertEqual(frame_locals['a'], 42) -- -- def test_frame_locals_outlive_generator_with_exec(self): -- def g(): -- a = 42 -- yield locals(), sys._getframe().f_locals -- -- locals_ = {'g': g} -- for i in range(10): -- exec("snapshot, live_locals = next(g())", locals=locals_) -- for l in (locals_['snapshot'], locals_['live_locals']): -- self.assertIn('a', l) -- self.assertEqual(l['a'], 42) -- -- -class GeneratorThrowTest(unittest.TestCase): +class GeneratorThrowTest(__TestCase): def test_exception_context_with_yield(self): def f(): -@@ -812,7 +805,7 @@ class GeneratorThrowTest(unittest.TestCase): +@@ -729,7 +805,7 @@ class GeneratorThrowTest(unittest.TestCase): gen.throw(ValueError) @@ -238,7 +155,7 @@ index e48d79d34f4..a48da0914b9 100644 def check_stack_names(self, frame, expected): names = [] -@@ -861,7 +854,7 @@ class GeneratorStackTraceTest(unittest.TestCase): +@@ -778,7 +854,7 @@ class GeneratorStackTraceTest(unittest.TestCase): self.check_yield_from_example(call_throw) @@ -247,7 +164,7 @@ index e48d79d34f4..a48da0914b9 100644 def test_generator_gi_yieldfrom(self): def a(): self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING) -@@ -2752,21 +2745,27 @@ test_generators just happened to be the test that drew these out. +@@ -2669,21 +2745,27 @@ test_generators just happened to be the test that drew these out. """ diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 9d7318105c90..cfb3241d712d 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -1515,6 +1515,76 @@ class TestGeneratorThrow(GeneratorTestsBase): self._compile_check(fn) + def test_return_const_value_in_except_and_finally(self): + def whoo(): + try: + yield 1 + except ValueError: + return 2 # noqa: B901 + finally: + return 3 # noqa: B012, SIM107, B901 + + def fn(t): + gen = whoo() + next(gen) + try: + gen.throw(ValueError) + except StopIteration as e: + assert e.args[0] == 3 + except Exception as e: + raise AssertionError from e + return t.sin() + + self._compile_check(fn) + + def test_return_value_in_except_and_finally(self): + class Foo: + def __init__(self, x): + self.x = x + + def whoo(): + try: + yield 1 + except ValueError: + return Foo(2) # noqa: B901 + finally: + return Foo(3) # noqa: B012, SIM107, B901 + + def fn(t): + gen = whoo() + next(gen) + try: + gen.throw(ValueError) + except StopIteration as e: + assert e.args[0].x == 3 + except Exception as e: + raise AssertionError from e + return t.sin() + + self._compile_check(fn) + + def test_return_None_in_except_and_finally(self): + def whoo(): + try: + yield 1 + except ValueError: + return 2 # noqa: B901 + finally: + return # noqa: B012, SIM107 + + def fn(t): + gen = whoo() + next(gen) + try: + gen.throw(ValueError) + except StopIteration as e: + assert len(e.args) == 0 + except Exception as e: + raise AssertionError from e + return t.sin() + + self._compile_check(fn) + instantiate_parametrized_tests(GeneratorTests) instantiate_parametrized_tests(TestGeneratorSend) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 8e5a1ef80393..087569127495 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3982,7 +3982,13 @@ class InliningInstructionTranslator(InstructionTranslatorBase): ): assert isinstance(self, InliningGeneratorInstructionTranslator) # When the generator returns None, we raise StopIteration - exc.raise_observed_exception(StopIteration, self) + args = [] + if not ( + isinstance(self.symbolic_result, ConstantVariable) + and self.symbolic_result.value is None + ): + args = [self.symbolic_result] + exc.raise_observed_exception(StopIteration, self, args=args) else: return self.symbolic_result else: