[generator] Raise StopIteration(value) with value from the return stmt (#157152)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157152
Approved by: https://github.com/zou3519
ghstack dependencies: #157148
This commit is contained in:
Guilherme Leobas
2025-08-13 21:16:44 -03:00
committed by PyTorch MergeBot
parent 831e85104a
commit d387a48c38
3 changed files with 85 additions and 92 deletions

View File

@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_generators.py b/test/dynamo/cpython/3_13/test_generators.py diff --git a/test/dynamo/cpython/3_13/test_generators.py b/test/dynamo/cpython/3_13/test_generators.py
index e48d79d34f4..a48da0914b9 100644 index 515fe7407f1..a48da0914b9 100644
--- a/test/dynamo/cpython/3_13/test_generators.py --- a/test/dynamo/cpython/3_13/test_generators.py
+++ b/test/dynamo/cpython/3_13/test_generators.py +++ b/test/dynamo/cpython/3_13/test_generators.py
@@ -1,3 +1,56 @@ @@ -1,3 +1,56 @@
@ -105,7 +105,8 @@ index e48d79d34f4..a48da0914b9 100644
+ return self.val + return self.val
+ +
+ # No __iter__ method + # No __iter__ method
+
-class ModifyUnderlyingIterableTest(unittest.TestCase):
+ class C: + class C:
+ +
+ def __iter__(self): + def __iter__(self):
@ -113,8 +114,7 @@ index e48d79d34f4..a48da0914b9 100644
+ +
+ self.assertEqual([1,2], list(i for i in C())) + self.assertEqual([1,2], list(i for i in C()))
+ +
+
-class ModifyUnderlyingIterableTest(unittest.TestCase):
+class ModifyUnderlyingIterableTest(__TestCase): +class ModifyUnderlyingIterableTest(__TestCase):
iterables = [ iterables = [
range(0), range(0),
@ -137,99 +137,16 @@ index e48d79d34f4..a48da0914b9 100644
def test_close_no_return_value(self): def test_close_no_return_value(self):
def f(): def f():
@@ -630,90 +706,7 @@ class GeneratorCloseTest(unittest.TestCase): @@ -630,7 +706,7 @@ class GeneratorCloseTest(unittest.TestCase):
self.assertIsNone(f_wr()) self.assertIsNone(f_wr())
-# See https://github.com/python/cpython/issues/125723
-class GeneratorDeallocTest(unittest.TestCase):
- def test_frame_outlives_generator(self):
- def g1():
- a = 42
- yield sys._getframe()
-
- def g2():
- a = 42
- yield
-
- def g3(obj):
- a = 42
- obj.frame = sys._getframe()
- yield
-
- class ObjectWithFrame():
- def __init__(self):
- self.frame = None
-
- def get_frame(index):
- if index == 1:
- return next(g1())
- elif index == 2:
- gen = g2()
- next(gen)
- return gen.gi_frame
- elif index == 3:
- obj = ObjectWithFrame()
- next(g3(obj))
- return obj.frame
- else:
- return None
-
- for index in (1, 2, 3):
- with self.subTest(index=index):
- frame = get_frame(index)
- frame_locals = frame.f_locals
- self.assertIn('a', frame_locals)
- self.assertEqual(frame_locals['a'], 42)
-
- def test_frame_locals_outlive_generator(self):
- frame_locals1 = None
-
- def g1():
- nonlocal frame_locals1
- frame_locals1 = sys._getframe().f_locals
- a = 42
- yield
-
- def g2():
- a = 42
- yield sys._getframe().f_locals
-
- def get_frame_locals(index):
- if index == 1:
- nonlocal frame_locals1
- next(g1())
- return frame_locals1
- if index == 2:
- return next(g2())
- else:
- return None
-
- for index in (1, 2):
- with self.subTest(index=index):
- frame_locals = get_frame_locals(index)
- self.assertIn('a', frame_locals)
- self.assertEqual(frame_locals['a'], 42)
-
- def test_frame_locals_outlive_generator_with_exec(self):
- def g():
- a = 42
- yield locals(), sys._getframe().f_locals
-
- locals_ = {'g': g}
- for i in range(10):
- exec("snapshot, live_locals = next(g())", locals=locals_)
- for l in (locals_['snapshot'], locals_['live_locals']):
- self.assertIn('a', l)
- self.assertEqual(l['a'], 42)
-
-
-class GeneratorThrowTest(unittest.TestCase): -class GeneratorThrowTest(unittest.TestCase):
+class GeneratorThrowTest(__TestCase): +class GeneratorThrowTest(__TestCase):
def test_exception_context_with_yield(self): def test_exception_context_with_yield(self):
def f(): def f():
@@ -812,7 +805,7 @@ class GeneratorThrowTest(unittest.TestCase): @@ -729,7 +805,7 @@ class GeneratorThrowTest(unittest.TestCase):
gen.throw(ValueError) gen.throw(ValueError)
@ -238,7 +155,7 @@ index e48d79d34f4..a48da0914b9 100644
def check_stack_names(self, frame, expected): def check_stack_names(self, frame, expected):
names = [] names = []
@@ -861,7 +854,7 @@ class GeneratorStackTraceTest(unittest.TestCase): @@ -778,7 +854,7 @@ class GeneratorStackTraceTest(unittest.TestCase):
self.check_yield_from_example(call_throw) self.check_yield_from_example(call_throw)
@ -247,7 +164,7 @@ index e48d79d34f4..a48da0914b9 100644
def test_generator_gi_yieldfrom(self): def test_generator_gi_yieldfrom(self):
def a(): def a():
self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING) self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING)
@@ -2752,21 +2745,27 @@ test_generators just happened to be the test that drew these out. @@ -2669,21 +2745,27 @@ test_generators just happened to be the test that drew these out.
""" """

View File

@ -1515,6 +1515,76 @@ class TestGeneratorThrow(GeneratorTestsBase):
self._compile_check(fn) self._compile_check(fn)
def test_return_const_value_in_except_and_finally(self):
def whoo():
try:
yield 1
except ValueError:
return 2 # noqa: B901
finally:
return 3 # noqa: B012, SIM107, B901
def fn(t):
gen = whoo()
next(gen)
try:
gen.throw(ValueError)
except StopIteration as e:
assert e.args[0] == 3
except Exception as e:
raise AssertionError from e
return t.sin()
self._compile_check(fn)
def test_return_value_in_except_and_finally(self):
class Foo:
def __init__(self, x):
self.x = x
def whoo():
try:
yield 1
except ValueError:
return Foo(2) # noqa: B901
finally:
return Foo(3) # noqa: B012, SIM107, B901
def fn(t):
gen = whoo()
next(gen)
try:
gen.throw(ValueError)
except StopIteration as e:
assert e.args[0].x == 3
except Exception as e:
raise AssertionError from e
return t.sin()
self._compile_check(fn)
def test_return_None_in_except_and_finally(self):
def whoo():
try:
yield 1
except ValueError:
return 2 # noqa: B901
finally:
return # noqa: B012, SIM107
def fn(t):
gen = whoo()
next(gen)
try:
gen.throw(ValueError)
except StopIteration as e:
assert len(e.args) == 0
except Exception as e:
raise AssertionError from e
return t.sin()
self._compile_check(fn)
instantiate_parametrized_tests(GeneratorTests) instantiate_parametrized_tests(GeneratorTests)
instantiate_parametrized_tests(TestGeneratorSend) instantiate_parametrized_tests(TestGeneratorSend)

View File

@ -3982,7 +3982,13 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
): ):
assert isinstance(self, InliningGeneratorInstructionTranslator) assert isinstance(self, InliningGeneratorInstructionTranslator)
# When the generator returns None, we raise StopIteration # When the generator returns None, we raise StopIteration
exc.raise_observed_exception(StopIteration, self) args = []
if not (
isinstance(self.symbolic_result, ConstantVariable)
and self.symbolic_result.value is None
):
args = [self.symbolic_result]
exc.raise_observed_exception(StopIteration, self, args=args)
else: else:
return self.symbolic_result return self.symbolic_result
else: else: