mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[generator] Raise StopIteration(value)
with value from the return stmt (#157152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157152 Approved by: https://github.com/zou3519 ghstack dependencies: #157148
This commit is contained in:
committed by
PyTorch MergeBot
parent
831e85104a
commit
d387a48c38
@ -1,5 +1,5 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/test_generators.py b/test/dynamo/cpython/3_13/test_generators.py
|
||||
index e48d79d34f4..a48da0914b9 100644
|
||||
index 515fe7407f1..a48da0914b9 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_generators.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_generators.py
|
||||
@@ -1,3 +1,56 @@
|
||||
@ -105,7 +105,8 @@ index e48d79d34f4..a48da0914b9 100644
|
||||
+ return self.val
|
||||
+
|
||||
+ # No __iter__ method
|
||||
+
|
||||
|
||||
-class ModifyUnderlyingIterableTest(unittest.TestCase):
|
||||
+ class C:
|
||||
+
|
||||
+ def __iter__(self):
|
||||
@ -113,8 +114,7 @@ index e48d79d34f4..a48da0914b9 100644
|
||||
+
|
||||
+ self.assertEqual([1,2], list(i for i in C()))
|
||||
+
|
||||
|
||||
-class ModifyUnderlyingIterableTest(unittest.TestCase):
|
||||
+
|
||||
+class ModifyUnderlyingIterableTest(__TestCase):
|
||||
iterables = [
|
||||
range(0),
|
||||
@ -137,99 +137,16 @@ index e48d79d34f4..a48da0914b9 100644
|
||||
|
||||
def test_close_no_return_value(self):
|
||||
def f():
|
||||
@@ -630,90 +706,7 @@ class GeneratorCloseTest(unittest.TestCase):
|
||||
@@ -630,7 +706,7 @@ class GeneratorCloseTest(unittest.TestCase):
|
||||
self.assertIsNone(f_wr())
|
||||
|
||||
|
||||
-# See https://github.com/python/cpython/issues/125723
|
||||
-class GeneratorDeallocTest(unittest.TestCase):
|
||||
- def test_frame_outlives_generator(self):
|
||||
- def g1():
|
||||
- a = 42
|
||||
- yield sys._getframe()
|
||||
-
|
||||
- def g2():
|
||||
- a = 42
|
||||
- yield
|
||||
-
|
||||
- def g3(obj):
|
||||
- a = 42
|
||||
- obj.frame = sys._getframe()
|
||||
- yield
|
||||
-
|
||||
- class ObjectWithFrame():
|
||||
- def __init__(self):
|
||||
- self.frame = None
|
||||
-
|
||||
- def get_frame(index):
|
||||
- if index == 1:
|
||||
- return next(g1())
|
||||
- elif index == 2:
|
||||
- gen = g2()
|
||||
- next(gen)
|
||||
- return gen.gi_frame
|
||||
- elif index == 3:
|
||||
- obj = ObjectWithFrame()
|
||||
- next(g3(obj))
|
||||
- return obj.frame
|
||||
- else:
|
||||
- return None
|
||||
-
|
||||
- for index in (1, 2, 3):
|
||||
- with self.subTest(index=index):
|
||||
- frame = get_frame(index)
|
||||
- frame_locals = frame.f_locals
|
||||
- self.assertIn('a', frame_locals)
|
||||
- self.assertEqual(frame_locals['a'], 42)
|
||||
-
|
||||
- def test_frame_locals_outlive_generator(self):
|
||||
- frame_locals1 = None
|
||||
-
|
||||
- def g1():
|
||||
- nonlocal frame_locals1
|
||||
- frame_locals1 = sys._getframe().f_locals
|
||||
- a = 42
|
||||
- yield
|
||||
-
|
||||
- def g2():
|
||||
- a = 42
|
||||
- yield sys._getframe().f_locals
|
||||
-
|
||||
- def get_frame_locals(index):
|
||||
- if index == 1:
|
||||
- nonlocal frame_locals1
|
||||
- next(g1())
|
||||
- return frame_locals1
|
||||
- if index == 2:
|
||||
- return next(g2())
|
||||
- else:
|
||||
- return None
|
||||
-
|
||||
- for index in (1, 2):
|
||||
- with self.subTest(index=index):
|
||||
- frame_locals = get_frame_locals(index)
|
||||
- self.assertIn('a', frame_locals)
|
||||
- self.assertEqual(frame_locals['a'], 42)
|
||||
-
|
||||
- def test_frame_locals_outlive_generator_with_exec(self):
|
||||
- def g():
|
||||
- a = 42
|
||||
- yield locals(), sys._getframe().f_locals
|
||||
-
|
||||
- locals_ = {'g': g}
|
||||
- for i in range(10):
|
||||
- exec("snapshot, live_locals = next(g())", locals=locals_)
|
||||
- for l in (locals_['snapshot'], locals_['live_locals']):
|
||||
- self.assertIn('a', l)
|
||||
- self.assertEqual(l['a'], 42)
|
||||
-
|
||||
-
|
||||
-class GeneratorThrowTest(unittest.TestCase):
|
||||
+class GeneratorThrowTest(__TestCase):
|
||||
|
||||
def test_exception_context_with_yield(self):
|
||||
def f():
|
||||
@@ -812,7 +805,7 @@ class GeneratorThrowTest(unittest.TestCase):
|
||||
@@ -729,7 +805,7 @@ class GeneratorThrowTest(unittest.TestCase):
|
||||
gen.throw(ValueError)
|
||||
|
||||
|
||||
@ -238,7 +155,7 @@ index e48d79d34f4..a48da0914b9 100644
|
||||
|
||||
def check_stack_names(self, frame, expected):
|
||||
names = []
|
||||
@@ -861,7 +854,7 @@ class GeneratorStackTraceTest(unittest.TestCase):
|
||||
@@ -778,7 +854,7 @@ class GeneratorStackTraceTest(unittest.TestCase):
|
||||
self.check_yield_from_example(call_throw)
|
||||
|
||||
|
||||
@ -247,7 +164,7 @@ index e48d79d34f4..a48da0914b9 100644
|
||||
def test_generator_gi_yieldfrom(self):
|
||||
def a():
|
||||
self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING)
|
||||
@@ -2752,21 +2745,27 @@ test_generators just happened to be the test that drew these out.
|
||||
@@ -2669,21 +2745,27 @@ test_generators just happened to be the test that drew these out.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -1515,6 +1515,76 @@ class TestGeneratorThrow(GeneratorTestsBase):
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_return_const_value_in_except_and_finally(self):
|
||||
def whoo():
|
||||
try:
|
||||
yield 1
|
||||
except ValueError:
|
||||
return 2 # noqa: B901
|
||||
finally:
|
||||
return 3 # noqa: B012, SIM107, B901
|
||||
|
||||
def fn(t):
|
||||
gen = whoo()
|
||||
next(gen)
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except StopIteration as e:
|
||||
assert e.args[0] == 3
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_return_value_in_except_and_finally(self):
|
||||
class Foo:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def whoo():
|
||||
try:
|
||||
yield 1
|
||||
except ValueError:
|
||||
return Foo(2) # noqa: B901
|
||||
finally:
|
||||
return Foo(3) # noqa: B012, SIM107, B901
|
||||
|
||||
def fn(t):
|
||||
gen = whoo()
|
||||
next(gen)
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except StopIteration as e:
|
||||
assert e.args[0].x == 3
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_return_None_in_except_and_finally(self):
|
||||
def whoo():
|
||||
try:
|
||||
yield 1
|
||||
except ValueError:
|
||||
return 2 # noqa: B901
|
||||
finally:
|
||||
return # noqa: B012, SIM107
|
||||
|
||||
def fn(t):
|
||||
gen = whoo()
|
||||
next(gen)
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except StopIteration as e:
|
||||
assert len(e.args) == 0
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(GeneratorTests)
|
||||
instantiate_parametrized_tests(TestGeneratorSend)
|
||||
|
@ -3982,7 +3982,13 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
):
|
||||
assert isinstance(self, InliningGeneratorInstructionTranslator)
|
||||
# When the generator returns None, we raise StopIteration
|
||||
exc.raise_observed_exception(StopIteration, self)
|
||||
args = []
|
||||
if not (
|
||||
isinstance(self.symbolic_result, ConstantVariable)
|
||||
and self.symbolic_result.value is None
|
||||
):
|
||||
args = [self.symbolic_result]
|
||||
exc.raise_observed_exception(StopIteration, self, args=args)
|
||||
else:
|
||||
return self.symbolic_result
|
||||
else:
|
||||
|
Reference in New Issue
Block a user