mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[generator] Raise StopIteration(value)
with value from the return stmt (#157152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157152 Approved by: https://github.com/zou3519 ghstack dependencies: #157148
This commit is contained in:
committed by
PyTorch MergeBot
parent
831e85104a
commit
d387a48c38
@ -1,5 +1,5 @@
|
|||||||
diff --git a/test/dynamo/cpython/3_13/test_generators.py b/test/dynamo/cpython/3_13/test_generators.py
|
diff --git a/test/dynamo/cpython/3_13/test_generators.py b/test/dynamo/cpython/3_13/test_generators.py
|
||||||
index e48d79d34f4..a48da0914b9 100644
|
index 515fe7407f1..a48da0914b9 100644
|
||||||
--- a/test/dynamo/cpython/3_13/test_generators.py
|
--- a/test/dynamo/cpython/3_13/test_generators.py
|
||||||
+++ b/test/dynamo/cpython/3_13/test_generators.py
|
+++ b/test/dynamo/cpython/3_13/test_generators.py
|
||||||
@@ -1,3 +1,56 @@
|
@@ -1,3 +1,56 @@
|
||||||
@ -105,7 +105,8 @@ index e48d79d34f4..a48da0914b9 100644
|
|||||||
+ return self.val
|
+ return self.val
|
||||||
+
|
+
|
||||||
+ # No __iter__ method
|
+ # No __iter__ method
|
||||||
+
|
|
||||||
|
-class ModifyUnderlyingIterableTest(unittest.TestCase):
|
||||||
+ class C:
|
+ class C:
|
||||||
+
|
+
|
||||||
+ def __iter__(self):
|
+ def __iter__(self):
|
||||||
@ -113,8 +114,7 @@ index e48d79d34f4..a48da0914b9 100644
|
|||||||
+
|
+
|
||||||
+ self.assertEqual([1,2], list(i for i in C()))
|
+ self.assertEqual([1,2], list(i for i in C()))
|
||||||
+
|
+
|
||||||
|
+
|
||||||
-class ModifyUnderlyingIterableTest(unittest.TestCase):
|
|
||||||
+class ModifyUnderlyingIterableTest(__TestCase):
|
+class ModifyUnderlyingIterableTest(__TestCase):
|
||||||
iterables = [
|
iterables = [
|
||||||
range(0),
|
range(0),
|
||||||
@ -137,99 +137,16 @@ index e48d79d34f4..a48da0914b9 100644
|
|||||||
|
|
||||||
def test_close_no_return_value(self):
|
def test_close_no_return_value(self):
|
||||||
def f():
|
def f():
|
||||||
@@ -630,90 +706,7 @@ class GeneratorCloseTest(unittest.TestCase):
|
@@ -630,7 +706,7 @@ class GeneratorCloseTest(unittest.TestCase):
|
||||||
self.assertIsNone(f_wr())
|
self.assertIsNone(f_wr())
|
||||||
|
|
||||||
|
|
||||||
-# See https://github.com/python/cpython/issues/125723
|
|
||||||
-class GeneratorDeallocTest(unittest.TestCase):
|
|
||||||
- def test_frame_outlives_generator(self):
|
|
||||||
- def g1():
|
|
||||||
- a = 42
|
|
||||||
- yield sys._getframe()
|
|
||||||
-
|
|
||||||
- def g2():
|
|
||||||
- a = 42
|
|
||||||
- yield
|
|
||||||
-
|
|
||||||
- def g3(obj):
|
|
||||||
- a = 42
|
|
||||||
- obj.frame = sys._getframe()
|
|
||||||
- yield
|
|
||||||
-
|
|
||||||
- class ObjectWithFrame():
|
|
||||||
- def __init__(self):
|
|
||||||
- self.frame = None
|
|
||||||
-
|
|
||||||
- def get_frame(index):
|
|
||||||
- if index == 1:
|
|
||||||
- return next(g1())
|
|
||||||
- elif index == 2:
|
|
||||||
- gen = g2()
|
|
||||||
- next(gen)
|
|
||||||
- return gen.gi_frame
|
|
||||||
- elif index == 3:
|
|
||||||
- obj = ObjectWithFrame()
|
|
||||||
- next(g3(obj))
|
|
||||||
- return obj.frame
|
|
||||||
- else:
|
|
||||||
- return None
|
|
||||||
-
|
|
||||||
- for index in (1, 2, 3):
|
|
||||||
- with self.subTest(index=index):
|
|
||||||
- frame = get_frame(index)
|
|
||||||
- frame_locals = frame.f_locals
|
|
||||||
- self.assertIn('a', frame_locals)
|
|
||||||
- self.assertEqual(frame_locals['a'], 42)
|
|
||||||
-
|
|
||||||
- def test_frame_locals_outlive_generator(self):
|
|
||||||
- frame_locals1 = None
|
|
||||||
-
|
|
||||||
- def g1():
|
|
||||||
- nonlocal frame_locals1
|
|
||||||
- frame_locals1 = sys._getframe().f_locals
|
|
||||||
- a = 42
|
|
||||||
- yield
|
|
||||||
-
|
|
||||||
- def g2():
|
|
||||||
- a = 42
|
|
||||||
- yield sys._getframe().f_locals
|
|
||||||
-
|
|
||||||
- def get_frame_locals(index):
|
|
||||||
- if index == 1:
|
|
||||||
- nonlocal frame_locals1
|
|
||||||
- next(g1())
|
|
||||||
- return frame_locals1
|
|
||||||
- if index == 2:
|
|
||||||
- return next(g2())
|
|
||||||
- else:
|
|
||||||
- return None
|
|
||||||
-
|
|
||||||
- for index in (1, 2):
|
|
||||||
- with self.subTest(index=index):
|
|
||||||
- frame_locals = get_frame_locals(index)
|
|
||||||
- self.assertIn('a', frame_locals)
|
|
||||||
- self.assertEqual(frame_locals['a'], 42)
|
|
||||||
-
|
|
||||||
- def test_frame_locals_outlive_generator_with_exec(self):
|
|
||||||
- def g():
|
|
||||||
- a = 42
|
|
||||||
- yield locals(), sys._getframe().f_locals
|
|
||||||
-
|
|
||||||
- locals_ = {'g': g}
|
|
||||||
- for i in range(10):
|
|
||||||
- exec("snapshot, live_locals = next(g())", locals=locals_)
|
|
||||||
- for l in (locals_['snapshot'], locals_['live_locals']):
|
|
||||||
- self.assertIn('a', l)
|
|
||||||
- self.assertEqual(l['a'], 42)
|
|
||||||
-
|
|
||||||
-
|
|
||||||
-class GeneratorThrowTest(unittest.TestCase):
|
-class GeneratorThrowTest(unittest.TestCase):
|
||||||
+class GeneratorThrowTest(__TestCase):
|
+class GeneratorThrowTest(__TestCase):
|
||||||
|
|
||||||
def test_exception_context_with_yield(self):
|
def test_exception_context_with_yield(self):
|
||||||
def f():
|
def f():
|
||||||
@@ -812,7 +805,7 @@ class GeneratorThrowTest(unittest.TestCase):
|
@@ -729,7 +805,7 @@ class GeneratorThrowTest(unittest.TestCase):
|
||||||
gen.throw(ValueError)
|
gen.throw(ValueError)
|
||||||
|
|
||||||
|
|
||||||
@ -238,7 +155,7 @@ index e48d79d34f4..a48da0914b9 100644
|
|||||||
|
|
||||||
def check_stack_names(self, frame, expected):
|
def check_stack_names(self, frame, expected):
|
||||||
names = []
|
names = []
|
||||||
@@ -861,7 +854,7 @@ class GeneratorStackTraceTest(unittest.TestCase):
|
@@ -778,7 +854,7 @@ class GeneratorStackTraceTest(unittest.TestCase):
|
||||||
self.check_yield_from_example(call_throw)
|
self.check_yield_from_example(call_throw)
|
||||||
|
|
||||||
|
|
||||||
@ -247,7 +164,7 @@ index e48d79d34f4..a48da0914b9 100644
|
|||||||
def test_generator_gi_yieldfrom(self):
|
def test_generator_gi_yieldfrom(self):
|
||||||
def a():
|
def a():
|
||||||
self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING)
|
self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING)
|
||||||
@@ -2752,21 +2745,27 @@ test_generators just happened to be the test that drew these out.
|
@@ -2669,21 +2745,27 @@ test_generators just happened to be the test that drew these out.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -1515,6 +1515,76 @@ class TestGeneratorThrow(GeneratorTestsBase):
|
|||||||
|
|
||||||
self._compile_check(fn)
|
self._compile_check(fn)
|
||||||
|
|
||||||
|
def test_return_const_value_in_except_and_finally(self):
|
||||||
|
def whoo():
|
||||||
|
try:
|
||||||
|
yield 1
|
||||||
|
except ValueError:
|
||||||
|
return 2 # noqa: B901
|
||||||
|
finally:
|
||||||
|
return 3 # noqa: B012, SIM107, B901
|
||||||
|
|
||||||
|
def fn(t):
|
||||||
|
gen = whoo()
|
||||||
|
next(gen)
|
||||||
|
try:
|
||||||
|
gen.throw(ValueError)
|
||||||
|
except StopIteration as e:
|
||||||
|
assert e.args[0] == 3
|
||||||
|
except Exception as e:
|
||||||
|
raise AssertionError from e
|
||||||
|
return t.sin()
|
||||||
|
|
||||||
|
self._compile_check(fn)
|
||||||
|
|
||||||
|
def test_return_value_in_except_and_finally(self):
|
||||||
|
class Foo:
|
||||||
|
def __init__(self, x):
|
||||||
|
self.x = x
|
||||||
|
|
||||||
|
def whoo():
|
||||||
|
try:
|
||||||
|
yield 1
|
||||||
|
except ValueError:
|
||||||
|
return Foo(2) # noqa: B901
|
||||||
|
finally:
|
||||||
|
return Foo(3) # noqa: B012, SIM107, B901
|
||||||
|
|
||||||
|
def fn(t):
|
||||||
|
gen = whoo()
|
||||||
|
next(gen)
|
||||||
|
try:
|
||||||
|
gen.throw(ValueError)
|
||||||
|
except StopIteration as e:
|
||||||
|
assert e.args[0].x == 3
|
||||||
|
except Exception as e:
|
||||||
|
raise AssertionError from e
|
||||||
|
return t.sin()
|
||||||
|
|
||||||
|
self._compile_check(fn)
|
||||||
|
|
||||||
|
def test_return_None_in_except_and_finally(self):
|
||||||
|
def whoo():
|
||||||
|
try:
|
||||||
|
yield 1
|
||||||
|
except ValueError:
|
||||||
|
return 2 # noqa: B901
|
||||||
|
finally:
|
||||||
|
return # noqa: B012, SIM107
|
||||||
|
|
||||||
|
def fn(t):
|
||||||
|
gen = whoo()
|
||||||
|
next(gen)
|
||||||
|
try:
|
||||||
|
gen.throw(ValueError)
|
||||||
|
except StopIteration as e:
|
||||||
|
assert len(e.args) == 0
|
||||||
|
except Exception as e:
|
||||||
|
raise AssertionError from e
|
||||||
|
return t.sin()
|
||||||
|
|
||||||
|
self._compile_check(fn)
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(GeneratorTests)
|
instantiate_parametrized_tests(GeneratorTests)
|
||||||
instantiate_parametrized_tests(TestGeneratorSend)
|
instantiate_parametrized_tests(TestGeneratorSend)
|
||||||
|
@ -3982,7 +3982,13 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||||||
):
|
):
|
||||||
assert isinstance(self, InliningGeneratorInstructionTranslator)
|
assert isinstance(self, InliningGeneratorInstructionTranslator)
|
||||||
# When the generator returns None, we raise StopIteration
|
# When the generator returns None, we raise StopIteration
|
||||||
exc.raise_observed_exception(StopIteration, self)
|
args = []
|
||||||
|
if not (
|
||||||
|
isinstance(self.symbolic_result, ConstantVariable)
|
||||||
|
and self.symbolic_result.value is None
|
||||||
|
):
|
||||||
|
args = [self.symbolic_result]
|
||||||
|
exc.raise_observed_exception(StopIteration, self, args=args)
|
||||||
else:
|
else:
|
||||||
return self.symbolic_result
|
return self.symbolic_result
|
||||||
else:
|
else:
|
||||||
|
Reference in New Issue
Block a user