mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150787 Approved by: https://github.com/zou3519
1493 lines
38 KiB
Python
1493 lines
38 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
import itertools
|
|
import sys
|
|
import unittest
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
from torch._dynamo.exc import Unsupported
|
|
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
|
|
from torch._dynamo.utils import counters
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
make_dynamo_test,
|
|
parametrize,
|
|
)
|
|
|
|
|
|
class GeneratorTestsBase(torch._dynamo.test_case.TestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self._old = torch._dynamo.config.enable_faithful_generator_behavior
|
|
torch._dynamo.config.enable_faithful_generator_behavior = True
|
|
|
|
def tearDown(self):
|
|
super().tearDown()
|
|
torch._dynamo.config.enable_faithful_generator_behavior = self._old
|
|
|
|
def _compile_check(self, fn, args=None, fullgraph=True):
|
|
eager = EagerAndRecordGraphs()
|
|
if args is None:
|
|
args = (torch.randn(2),)
|
|
r = torch.compile(fn, backend=eager, fullgraph=fullgraph)(*args)
|
|
self.assertGreater(len(eager.graphs), 0)
|
|
return r
|
|
|
|
|
|
class GeneratorTests(GeneratorTestsBase):
|
|
def test_generator_simple(self):
|
|
def whoo():
|
|
yield 1
|
|
yield 2
|
|
yield 3
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo()
|
|
t = t + next(gen)
|
|
t = t + next(gen)
|
|
t = t + next(gen)
|
|
return t
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t + 6)
|
|
|
|
def test_infinite_generator(self):
|
|
def whoo():
|
|
i = 0
|
|
while True:
|
|
yield i
|
|
i += 1
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo()
|
|
t = t + next(gen)
|
|
t = t + next(gen)
|
|
t = t + next(gen)
|
|
return t
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t + 3)
|
|
|
|
def test_infinite_generator_2(self):
|
|
def whoo(t):
|
|
i = 0
|
|
while True:
|
|
yield t + i
|
|
i += 1
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
return list(zip(range(3), whoo(t)))
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, list(zip(range(3), whoo(t))))
|
|
|
|
def test_infinite_generator_3(self):
|
|
def whoo(i):
|
|
while True:
|
|
yield i
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
return list(zip(range(3), whoo(1))), t.sin()
|
|
|
|
t = torch.randn(2)
|
|
y, _ = fn(t)
|
|
self.assertEqual(y, list(zip(range(3), whoo(1))))
|
|
|
|
def test_graph_break_in_generator(self):
|
|
def whoo():
|
|
yield 1
|
|
torch._dynamo.graph_break()
|
|
yield 2
|
|
|
|
eager = EagerAndRecordGraphs()
|
|
|
|
@torch.compile(backend=eager, fullgraph=False)
|
|
def fn(t):
|
|
gen = whoo()
|
|
s = next(gen)
|
|
s += next(gen)
|
|
return t + s
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t + 3)
|
|
self.assertEqual(len(eager.graphs), 0)
|
|
|
|
def test_graph_break_in_generator_2(self):
|
|
def whoo(x):
|
|
yield x.sin()
|
|
torch._dynamo.graph_break()
|
|
yield x.cos()
|
|
|
|
def call_whoo(x):
|
|
gen = whoo(x)
|
|
sin = next(gen)
|
|
cos = next(gen)
|
|
return sin, cos
|
|
|
|
eager = EagerAndRecordGraphs()
|
|
|
|
@torch.compile(backend=eager, fullgraph=False)
|
|
def fn(t):
|
|
sin, cos = call_whoo(t)
|
|
return sin + cos
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t.sin() + t.cos())
|
|
self.assertEqual(len(eager.graphs), 1)
|
|
self.assertExpectedInline(
|
|
normalize_gm(eager.graphs[0].print_readable(False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_stack0_0_: "f32[2]", L_stack0_1_: "f32[2]"):
|
|
l_stack0_0_ = L_stack0_0_
|
|
l_stack0_1_ = L_stack0_1_
|
|
|
|
add: "f32[2]" = l_stack0_0_ + l_stack0_1_; l_stack0_0_ = l_stack0_1_ = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
def test_reconstruct_generator_with_local_var_mutation(self):
|
|
def whoo(t):
|
|
x = 0
|
|
yield t.sin() + x
|
|
x += 1
|
|
yield t.cos() + x
|
|
x += 1
|
|
yield t.tan() + x
|
|
|
|
@torch.compile(backend="eager", fullgraph=False)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
next(gen)
|
|
return t.sin(), gen
|
|
|
|
t = torch.randn(2)
|
|
y, g = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertEqual(list(g), [t.cos() + 1, t.tan() + 2])
|
|
|
|
def test_reconstruct_generator_with_dict_mutation(self):
|
|
counters.clear()
|
|
|
|
def whoo(t, d):
|
|
d[2] = t
|
|
yield t.sin()
|
|
yield t.cos()
|
|
d[3] = t + 1
|
|
yield t.tan()
|
|
|
|
@torch.compile(backend="eager", fullgraph=False)
|
|
def fn(t, d):
|
|
gen = whoo(t, d)
|
|
next(gen)
|
|
return t.sin(), whoo(t, d)
|
|
|
|
t = torch.randn(2)
|
|
d = {1: t}
|
|
fn(t, d)
|
|
self.assertEqual(len(counters["unimplemented"]), 1)
|
|
self.assertEqual(
|
|
dict(counters["unimplemented"]),
|
|
{
|
|
"Cannot reconstruct a generator with variable mutations. "
|
|
"Dynamo needs to fully exhaust the generator, which may cause "
|
|
"unintended variable modifications.": 1
|
|
},
|
|
)
|
|
|
|
def test_reconstruct_generator_with_dict_mutation_before(self):
|
|
def whoo(t, d):
|
|
d[2] = t
|
|
yield t.sin()
|
|
yield t.cos()
|
|
yield t.tan()
|
|
|
|
@torch.compile(backend="eager", fullgraph=False)
|
|
def fn(t, d):
|
|
gen = whoo(t, d)
|
|
next(gen)
|
|
return t.sin(), gen
|
|
|
|
t = torch.randn(2)
|
|
d = {1: t}
|
|
y, g = fn(t, d)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertEqual(list(g), [t.cos(), t.tan()])
|
|
self.assertEqual(d, {1: t, 2: t})
|
|
|
|
def test_reconstruct_generator_with_object_mutation(self):
|
|
class Counter:
|
|
def __init__(self):
|
|
self.x = 0
|
|
|
|
def incr(self):
|
|
self.x += 1
|
|
|
|
def whoo(t, c):
|
|
c.incr()
|
|
yield t.sin()
|
|
yield t.cos()
|
|
c.incr()
|
|
yield t.tan()
|
|
|
|
@torch.compile(backend="eager", fullgraph=False)
|
|
def fn(t, c):
|
|
gen = whoo(t, c)
|
|
next(gen)
|
|
return t.sin(), gen
|
|
|
|
t = torch.randn(2)
|
|
c = Counter()
|
|
fn(t, c)
|
|
self.assertEqual(len(counters["unimplemented"]), 1)
|
|
self.assertEqual(
|
|
dict(counters["unimplemented"]),
|
|
{
|
|
"Cannot reconstruct a generator with variable mutations. "
|
|
"Dynamo needs to fully exhaust the generator, which may cause "
|
|
"unintended variable modifications.": 1
|
|
},
|
|
)
|
|
|
|
def test_reconstruct_generator_with_object_mutation_before(self):
|
|
class Counter:
|
|
def __init__(self):
|
|
self.x = 0
|
|
|
|
def incr(self):
|
|
self.x += 1
|
|
|
|
def whoo(t, c):
|
|
c.incr()
|
|
yield t.sin()
|
|
yield t.cos()
|
|
yield t.tan()
|
|
|
|
@torch.compile(backend="eager", fullgraph=False)
|
|
def fn(t, c):
|
|
gen = whoo(t, c)
|
|
next(gen)
|
|
# We should be able to reconstruct the generator as there's no object
|
|
# mutation after the first yield
|
|
return t.sin(), gen
|
|
|
|
t = torch.randn(2)
|
|
c = Counter()
|
|
y, g = fn(t, c)
|
|
self.assertEqual(c.x, 1)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertEqual(list(g), [t.cos(), t.tan()])
|
|
|
|
def test_graph_break_and_reconstruct_generator(self):
|
|
def whoo(t):
|
|
yield t.sin()
|
|
yield t.cos()
|
|
yield t.tan()
|
|
|
|
def g(t):
|
|
torch._dynamo.graph_break()
|
|
|
|
@torch.compile(backend="eager", fullgraph=False)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
next(gen)
|
|
g(t)
|
|
return t.sin(), list(gen)
|
|
|
|
t = torch.randn(2)
|
|
y, gen = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertEqual(list(gen), [t.cos(), t.tan()])
|
|
|
|
def test_graph_break_in_generator_while_reconstructing(self):
|
|
counters.clear()
|
|
|
|
def whoo():
|
|
yield 1
|
|
torch._dynamo.graph_break()
|
|
yield 2
|
|
|
|
eager = EagerAndRecordGraphs()
|
|
|
|
@torch.compile(backend=eager, fullgraph=False)
|
|
def fn(t):
|
|
gen = whoo()
|
|
s = next(gen)
|
|
torch._dynamo.graph_break()
|
|
s += next(gen)
|
|
return t + s
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t + 3)
|
|
self.assertEqual(len(eager.graphs), 0)
|
|
|
|
def test_generator_as_argument(self):
|
|
# The inline tracer needs to be kept in sync if an already advanced generator
|
|
# is given to a compiled function.
|
|
def whoo():
|
|
yield 1
|
|
yield 2
|
|
yield 3
|
|
|
|
eager = EagerAndRecordGraphs()
|
|
|
|
@torch.compile(backend=eager, fullgraph=True)
|
|
def fn(t, ctx):
|
|
return t + next(ctx)
|
|
|
|
t = torch.randn(2)
|
|
ctx = whoo()
|
|
next(ctx)
|
|
with self.assertRaisesRegex(
|
|
Unsupported, "Generator as graph argument is not supported"
|
|
):
|
|
fn(t, ctx)
|
|
|
|
def test_generator_as_argument_2(self):
|
|
def whoo(x):
|
|
yield x.sin()
|
|
yield x.cos()
|
|
|
|
eager = EagerAndRecordGraphs()
|
|
|
|
@torch.compile(backend=eager, fullgraph=True)
|
|
def fn(t, ctx):
|
|
return t + next(ctx)
|
|
|
|
t = torch.randn(2)
|
|
ctx = whoo(t)
|
|
next(ctx)
|
|
with self.assertRaisesRegex(
|
|
Unsupported, "Generator as graph argument is not supported"
|
|
):
|
|
fn(t, ctx)
|
|
|
|
def test_generator_as_argument_3(self):
|
|
# The inline tracer needs to be kept in sync if an already advanced generator
|
|
# is given to a compiled function.
|
|
def whoo():
|
|
yield 1
|
|
yield 2
|
|
yield 3
|
|
|
|
eager = EagerAndRecordGraphs()
|
|
|
|
@torch.compile(backend=eager, fullgraph=True)
|
|
def fn(t, ctx):
|
|
return t + next(ctx)
|
|
|
|
t = torch.randn(2)
|
|
ctx = whoo()
|
|
with self.assertRaisesRegex(
|
|
Unsupported, "Generator as graph argument is not supported"
|
|
):
|
|
fn(t, ctx)
|
|
|
|
def test_generator_as_argument_4(self):
|
|
def whoo(x):
|
|
yield x.sin()
|
|
yield x.cos()
|
|
|
|
eager = EagerAndRecordGraphs()
|
|
|
|
@torch.compile(backend=eager, fullgraph=True)
|
|
def fn(t, ctx):
|
|
return t + next(ctx)
|
|
|
|
t = torch.randn(2)
|
|
ctx = whoo(t)
|
|
with self.assertRaisesRegex(
|
|
Unsupported, "Generator as graph argument is not supported"
|
|
):
|
|
fn(t, ctx)
|
|
|
|
def test_islice_chain(self):
|
|
eager = EagerAndRecordGraphs()
|
|
|
|
@torch.compile(backend=eager, fullgraph=True)
|
|
def fn(t):
|
|
tmp1 = [t + 1, t + 2]
|
|
tmp2 = [t + 3, t + 4]
|
|
return list(itertools.chain(tmp1, tmp2))
|
|
|
|
t = torch.tensor([1.0])
|
|
y = fn(t)
|
|
self.assertEqual(y, [t + 1, t + 2, t + 3, t + 4])
|
|
|
|
def test_zip_generator(self):
|
|
def whoo(t):
|
|
yield t + 1
|
|
yield t + 2
|
|
yield t + 3
|
|
|
|
def fn(t):
|
|
return zip(range(3), whoo(t)), t.sin()
|
|
|
|
t = torch.randn(2)
|
|
z, _ = self._compile_check(fn, args=(t,))
|
|
self.assertEqual(list(z), list(zip(range(3), whoo(t))))
|
|
|
|
@unittest.expectedFailure
|
|
def test_zip_generator_2(self):
|
|
def bar(t, i):
|
|
return t + i
|
|
|
|
def whoo(t):
|
|
yield bar(t, 1)
|
|
yield bar(t, 2)
|
|
yield bar(t, 3)
|
|
|
|
def fn(t):
|
|
return zip(range(3), whoo(t))
|
|
|
|
t = torch.randn(3)
|
|
y = self._compile_check(fn, args=(t,), fullgraph=False)
|
|
expected = list(zip(range(3), whoo(t)))
|
|
self.assertEqual(expected, list(y))
|
|
|
|
def test_zip_subgenerator(self):
|
|
def subgen(t):
|
|
yield t + 1
|
|
yield t + 2
|
|
|
|
def whoo(t):
|
|
yield from subgen(t)
|
|
yield t + 3
|
|
|
|
def fn(t):
|
|
return zip(range(3), whoo(t)), t.sin()
|
|
|
|
t = torch.randn(2)
|
|
z, _ = self._compile_check(fn, args=(t,))
|
|
self.assertEqual(list(z), list(zip(range(3), whoo(t))))
|
|
|
|
def test_list_zip_generator(self):
|
|
def whoo(t):
|
|
yield t + 1
|
|
yield t + 2
|
|
yield t + 3
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
return list(zip(range(3), whoo(t)))
|
|
|
|
t = torch.randn(3)
|
|
y = fn(t)
|
|
expected = list(zip(range(3), whoo(t)))
|
|
self.assertEqual(expected, y)
|
|
|
|
def test_zip_infinite_generator(self):
|
|
def whoo(t):
|
|
i = 0
|
|
while True:
|
|
yield t + i
|
|
i += 1
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
return list(zip(range(3), whoo(t)))
|
|
|
|
t = torch.randn(3)
|
|
y = fn(t)
|
|
expected = list(zip(range(3), whoo(t)))
|
|
self.assertEqual(expected, y)
|
|
|
|
@parametrize("container", [list, tuple, dict, OrderedDict])
|
|
def test_dict_tuple_list_generator(self, container):
|
|
def whoo(t):
|
|
yield 1, t + 1
|
|
yield 2, t + 2
|
|
yield 3, t + 3
|
|
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
return container(gen)
|
|
|
|
t = torch.randn(2)
|
|
expected = fn(t)
|
|
got = torch.compile(backend="eager", fullgraph=True)(fn)(t)
|
|
self.assertEqual(expected, got)
|
|
|
|
def test_return_generator(self):
|
|
def whoo(t):
|
|
yield t + 1
|
|
yield t + 2
|
|
yield t + 3
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
return gen
|
|
|
|
t = torch.tensor([1.0])
|
|
gen = fn(t)
|
|
self.assertEqual(list(gen), [t + 1, t + 2, t + 3])
|
|
|
|
def test_return_tuple_generator(self):
|
|
def whoo(t):
|
|
yield t.sin()
|
|
yield t.cos()
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
g1, g2 = whoo(t), whoo(t + 1)
|
|
return (g1, g2), t.sin()
|
|
|
|
t = torch.randn(2)
|
|
(g1, g2), _ = fn(t)
|
|
self.assertEqual(list(g1), [t.sin(), t.cos()])
|
|
self.assertEqual(list(g2), [(t + 1).sin(), (t + 1).cos()])
|
|
|
|
def test_return_advanced_generator(self):
|
|
def whoo(t):
|
|
yield t + 1
|
|
yield t + 2
|
|
yield t + 3
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
next(gen)
|
|
return gen
|
|
|
|
t = torch.tensor([1.0])
|
|
gen = fn(t)
|
|
self.assertEqual(list(gen), [t + 2, t + 3])
|
|
|
|
def test_return_exhaust_generator(self):
|
|
def whoo(t):
|
|
yield t + 1
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
next(gen)
|
|
return gen
|
|
|
|
t = torch.tensor([1.0])
|
|
gen = fn(t)
|
|
with self.assertRaises(StopIteration):
|
|
next(gen)
|
|
|
|
@unittest.expectedFailure
|
|
def test_reconstruct_generator_tensor_mutation(self):
|
|
def whoo(t):
|
|
yield t.sin_()
|
|
yield t.cos_()
|
|
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
return gen
|
|
|
|
with self.assertRaisesRegex(
|
|
Unsupported,
|
|
"Cannot reconstruct a generator with variable mutations",
|
|
):
|
|
self._compile_check(fn)
|
|
|
|
def test_subgenerator(self):
|
|
def subgen(t):
|
|
yield t + 1
|
|
yield t + 2
|
|
|
|
def main_gen(t):
|
|
yield from subgen(t)
|
|
yield t + 3
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = main_gen(t)
|
|
return list(gen)
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, [t + 1, t + 2, t + 3])
|
|
|
|
def test_return_subgenerator(self):
|
|
def subgen(t):
|
|
yield t + 1
|
|
yield t + 2
|
|
|
|
def main_gen(t):
|
|
yield from subgen(t)
|
|
yield t + 3
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = main_gen(t)
|
|
next(gen)
|
|
return gen
|
|
|
|
t = torch.randn(2)
|
|
gen = fn(t)
|
|
self.assertEqual(list(gen), [t + 2, t + 3])
|
|
|
|
def test_dynamo_disable_generator(self):
|
|
@torch._dynamo.disable
|
|
def main_gen(t):
|
|
yield t + 1
|
|
yield t + 2
|
|
yield t + 3
|
|
|
|
@torch.compile(backend="eager", fullgraph=False)
|
|
def fn(t):
|
|
gen = main_gen(t)
|
|
return list(gen)
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, [t + 1, t + 2, t + 3])
|
|
|
|
def test_dynamo_disable_sub_generator(self):
|
|
@torch._dynamo.disable
|
|
def subgen(t):
|
|
yield t + 2
|
|
yield t + 3
|
|
|
|
def main_gen(t):
|
|
yield t + 1
|
|
yield from subgen(t)
|
|
|
|
@torch.compile(backend="eager", fullgraph=False)
|
|
def fn(t):
|
|
gen = main_gen(t)
|
|
return list(gen)
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, [t + 1, t + 2, t + 3])
|
|
|
|
def test_graph_break_outside_generator(self):
|
|
def whoo(t):
|
|
yield t + 1
|
|
yield t + 2
|
|
|
|
@torch.compile(backend="eager", fullgraph=False)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
x = next(gen)
|
|
torch._dynamo.graph_break()
|
|
y = next(gen)
|
|
return x + y
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, (t + 1) + (t + 2))
|
|
|
|
def test_graph_break_before_calling_generator(self):
|
|
def whoo(t):
|
|
for perm in itertools.product(itertools.permutations((0, 1, 2)), repeat=1):
|
|
yield sum(perm[0])
|
|
|
|
def fn(t):
|
|
s = 0
|
|
for b, p in itertools.product(whoo(t), itertools.permutations((4, 5))):
|
|
s += b
|
|
return s
|
|
|
|
t = torch.randn(2)
|
|
expected = fn(t)
|
|
got = torch.compile(backend="eager", fullgraph=False)(fn)(t)
|
|
self.assertEqual(expected, got)
|
|
|
|
def test_generator_with_side_effects(self):
|
|
counters.clear()
|
|
i = 0
|
|
|
|
def whoo(t):
|
|
nonlocal i
|
|
for j in range(5):
|
|
i += 1
|
|
yield t + j
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
return whoo(t), t.sin()
|
|
|
|
t = torch.randn(2)
|
|
fn(t)
|
|
self.assertEqual(len(counters["unimplemented"]), 1)
|
|
self.assertEqual(
|
|
dict(counters["unimplemented"]),
|
|
{
|
|
"Cannot reconstruct a generator with variable mutations. "
|
|
"Dynamo needs to fully exhaust the generator, which may cause "
|
|
"unintended variable modifications.": 1
|
|
},
|
|
)
|
|
|
|
def test_subgenerator_with_side_effects(self):
|
|
i = 0
|
|
|
|
def subgen(t):
|
|
nonlocal i
|
|
i += 1
|
|
yield t
|
|
i += 1
|
|
yield t + 1
|
|
|
|
def whoo(t):
|
|
nonlocal i
|
|
yield from subgen(t)
|
|
i += 1
|
|
yield t + 2
|
|
i += 1
|
|
yield t + 3
|
|
i += 1
|
|
yield t + 4
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
return whoo(t), t.sin()
|
|
|
|
t = torch.randn(2)
|
|
gen, y = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertEqual(len(list(gen)), 5)
|
|
self.assertTrue(
|
|
"Cannot reconstruct a generator with variable mutations. "
|
|
"Dynamo needs to fully exhaust the generator, which may cause "
|
|
"unintended variable modifications." in dict(counters["unimplemented"])
|
|
)
|
|
|
|
def test_generator_with_side_effects_graph_break(self):
|
|
i = 0
|
|
|
|
def whoo(t):
|
|
nonlocal i
|
|
for j in range(5):
|
|
i += 1
|
|
yield t + j
|
|
|
|
@torch.compile(backend="eager", fullgraph=False)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
torch._dynamo.graph_break()
|
|
next(gen)
|
|
return gen, t.sin()
|
|
|
|
t = torch.randn(2)
|
|
gen, y = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertEqual(len(list(gen)), 4)
|
|
self.assertTrue(
|
|
"Cannot reconstruct a generator with variable mutations. "
|
|
"Dynamo needs to fully exhaust the generator, which may cause "
|
|
"unintended variable modifications." in dict(counters["unimplemented"])
|
|
)
|
|
|
|
def test_generator_with_side_effects_graph_break_2(self):
|
|
i = 0
|
|
|
|
def whoo(t):
|
|
nonlocal i
|
|
for j in range(5):
|
|
i += 1
|
|
yield t + j
|
|
torch._dynamo.graph_break()
|
|
|
|
eager = EagerAndRecordGraphs()
|
|
|
|
@torch.compile(backend=eager, fullgraph=False)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
return list(zip(range(3), gen))
|
|
|
|
t = torch.randn(2)
|
|
fn(t)
|
|
self.assertEqual(len(eager.graphs), 0)
|
|
|
|
@unittest.skipIf(sys.version_info < (3, 12), "Test CLEANUP_THROW")
|
|
@unittest.expectedFailure
|
|
def test_cleanup_throw(self):
|
|
def nested_generator():
|
|
try:
|
|
yield 1
|
|
yield 2
|
|
except StopIteration:
|
|
return 123 # noqa: B901
|
|
|
|
def outer_generator():
|
|
yield from nested_generator()
|
|
yield 3
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = outer_generator()
|
|
next(gen) # Start the outer generator and enter the nested generato
|
|
|
|
i = 0
|
|
try:
|
|
# Force an exception while the generator is running
|
|
i = gen.throw(StopIteration("stop"))
|
|
except RuntimeError:
|
|
pass
|
|
return (i, t.sin())
|
|
|
|
t = torch.randn(2)
|
|
i, y = self._compile_check(fn, args=(t,))
|
|
self.assertEqual(i, 3)
|
|
self.assertEqual(y, t.sin())
|
|
|
|
def test_iter(self):
|
|
def whoo():
|
|
i = 0
|
|
while True:
|
|
yield i
|
|
i += 1
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
s = 0
|
|
for i in whoo():
|
|
if i > 5:
|
|
break
|
|
s += i
|
|
return t + s
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t + sum(range(6)))
|
|
|
|
def test_list_extend(self):
|
|
def f(x):
|
|
y = [1]
|
|
y.extend(y[-1] + z for z in range(3))
|
|
return x + 1, y
|
|
|
|
self.assertEqual(
|
|
f(torch.ones(3)),
|
|
torch.compile(f, backend="eager", fullgraph=True)(torch.ones(3)),
|
|
)
|
|
|
|
def test_deque_extendleft(self):
|
|
import collections
|
|
|
|
def f(x):
|
|
y = collections.deque([1])
|
|
y.extendleft(y[0] + z for z in range(3))
|
|
return x + 1, y
|
|
|
|
self.assertEqual(
|
|
f(torch.ones(3)),
|
|
torch.compile(f, backend="eager", fullgraph=True)(torch.ones(3)),
|
|
)
|
|
|
|
|
|
class TestGeneratorSend(GeneratorTestsBase):
|
|
def test_send(self):
|
|
def double():
|
|
x = yield
|
|
yield x * 2
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = double()
|
|
next(gen)
|
|
return gen.send(t)
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t * 2)
|
|
|
|
@parametrize("fullgraph", [True, False])
|
|
def test_send_stop_iteration(self, fullgraph):
|
|
def double():
|
|
x = yield
|
|
yield x * 2
|
|
|
|
@torch.compile(backend="eager", fullgraph=fullgraph)
|
|
def fn(t):
|
|
gen = double()
|
|
next(gen)
|
|
a = gen.send(t)
|
|
b = gen.send(t) # should result in StopIteration
|
|
return a + b
|
|
|
|
t = torch.randn(2)
|
|
if fullgraph:
|
|
with self.assertRaisesRegex(Unsupported, "Observed exception"):
|
|
fn(t)
|
|
else:
|
|
with self.assertRaises(StopIteration):
|
|
fn(t)
|
|
|
|
|
|
class TestGeneratorClose(GeneratorTestsBase):
|
|
def test_close(self):
|
|
def whoo(t):
|
|
yield t.sin()
|
|
yield t.cos()
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
i = next(gen)
|
|
gen.close()
|
|
return i
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
|
|
def test_close_subgen(self):
|
|
z = 0
|
|
|
|
def subgen(t):
|
|
nonlocal z
|
|
z = 1
|
|
yield t.sin()
|
|
z = 3
|
|
yield t.cos()
|
|
|
|
def whoo(t):
|
|
yield from subgen(t)
|
|
yield t.tan()
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
i = next(gen)
|
|
gen.close()
|
|
return i
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertEqual(z, 1)
|
|
|
|
def test_close_with_side_effects(self):
|
|
L = []
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
try:
|
|
L.append(1)
|
|
yield t.sin()
|
|
L.append(2)
|
|
yield t.cos()
|
|
finally:
|
|
L.append(z)
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
nonlocal z
|
|
gen = whoo(t)
|
|
i = next(gen)
|
|
z = -123
|
|
gen.close()
|
|
L.append(len(L))
|
|
return i
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertEqual(L, [1, -123, 2])
|
|
|
|
def test_close_capture_GeneratorExit_return(self):
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
try:
|
|
z += 1
|
|
yield t.sin()
|
|
yield t.cos()
|
|
except GeneratorExit:
|
|
z += 10
|
|
return t.tan() # noqa: B901
|
|
finally:
|
|
z += 100
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
nonlocal z
|
|
gen = whoo(t)
|
|
i = next(gen)
|
|
y = gen.close()
|
|
return (i, y)
|
|
|
|
t = torch.randn(2)
|
|
(i, y) = fn(t)
|
|
self.assertEqual(i, t.sin())
|
|
self.assertEqual(y, t.tan())
|
|
self.assertEqual(z, 111)
|
|
|
|
@parametrize("fullgraph", [True, False])
|
|
def test_close_capture_GeneratorExit(self, fullgraph):
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
try:
|
|
yield t.sin()
|
|
yield t.cos()
|
|
except GeneratorExit:
|
|
yield t.tan()
|
|
finally:
|
|
z = 1
|
|
|
|
@torch.compile(backend="eager", fullgraph=fullgraph)
|
|
def fn(t):
|
|
nonlocal z
|
|
gen = whoo(t)
|
|
i = next(gen)
|
|
gen.close()
|
|
return i
|
|
|
|
t = torch.randn(2)
|
|
if fullgraph:
|
|
# This should actually be RuntimeError("generator ignored GeneratorExit")
|
|
# but Dynamo swallow the exception and raises Unsupported instead
|
|
with self.assertRaisesRegex(Unsupported, "Observed exception"):
|
|
fn(t)
|
|
else:
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "generator ignored GeneratorExit"
|
|
):
|
|
fn(t)
|
|
|
|
def test_close_capture_and_reraise_GeneratorExit(self):
|
|
L = []
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
try:
|
|
L.append(1)
|
|
yield t.sin()
|
|
yield t.cos()
|
|
except GeneratorExit:
|
|
L.append(z)
|
|
z = -1
|
|
raise
|
|
finally:
|
|
L.append(z)
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
nonlocal z
|
|
gen = whoo(t)
|
|
i = next(gen)
|
|
z = -123
|
|
gen.close()
|
|
L.append(456)
|
|
return i
|
|
|
|
t = torch.randn(2)
|
|
y = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertEqual(L, [1, -123, -1, 456])
|
|
|
|
@parametrize("exc", [RuntimeError, AttributeError])
|
|
@make_dynamo_test
|
|
def test_close_capture_and_reraise_exc(self, exc):
|
|
def whoo(t):
|
|
try:
|
|
yield t.sin()
|
|
yield t.cos()
|
|
except GeneratorExit as e:
|
|
raise exc from e
|
|
finally:
|
|
pass
|
|
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
i = next(gen)
|
|
gen.close()
|
|
return i
|
|
|
|
t = torch.randn(2)
|
|
|
|
z = 0
|
|
try:
|
|
fn(t)
|
|
except exc:
|
|
z = 1
|
|
finally:
|
|
assert z == 1
|
|
|
|
def test_close_with_subgen(self):
|
|
L = []
|
|
z = 0
|
|
|
|
def subgen(t):
|
|
yield t.sin()
|
|
yield t.cos()
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
L.append(10)
|
|
yield from subgen(t)
|
|
L.append(20)
|
|
try:
|
|
L.append(1)
|
|
z = 4
|
|
yield t.tan()
|
|
finally:
|
|
L.append(z)
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
nonlocal z
|
|
gen = whoo(t)
|
|
i = next(gen)
|
|
z = -123
|
|
gen.close()
|
|
L.append(456)
|
|
return i, t.sin()
|
|
|
|
t = torch.randn(2)
|
|
y, _ = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertEqual(L, [10, 456])
|
|
self.assertEqual(z, -123)
|
|
|
|
def test_close_after_close(self):
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
try:
|
|
z += 1
|
|
yield t.sin()
|
|
yield t.cos()
|
|
finally:
|
|
# finally should only be executed once
|
|
z += 1
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
i = next(gen)
|
|
gen.close()
|
|
return (i, gen.close())
|
|
|
|
t = torch.randn(2)
|
|
(i, y) = fn(t)
|
|
self.assertEqual(i, t.sin())
|
|
self.assertEqual(y, None)
|
|
self.assertEqual(z, 2)
|
|
|
|
@parametrize("fullgraph", [True, False])
|
|
def test_next_after_close(self, fullgraph):
|
|
def whoo(t):
|
|
yield t.sin()
|
|
yield t.cos()
|
|
|
|
@torch.compile(backend="eager", fullgraph=fullgraph)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
gen.close()
|
|
a = next(gen)
|
|
return [t.sin(), a]
|
|
|
|
t = torch.randn(3)
|
|
if fullgraph:
|
|
with self.assertRaises(Unsupported):
|
|
fn(t)
|
|
else:
|
|
with self.assertRaises(StopIteration):
|
|
fn(t)
|
|
|
|
def test_close_after_exception(self):
|
|
def whoo(t):
|
|
raise ValueError("foo")
|
|
yield t.cos()
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
try:
|
|
next(gen)
|
|
except ValueError:
|
|
pass
|
|
b = gen.close()
|
|
return [t.sin(), b]
|
|
|
|
t = torch.randn(2)
|
|
y, b = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertIsNone(b)
|
|
|
|
def test_close_handling_finally(self):
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
try:
|
|
yield t.sin()
|
|
yield t.cos()
|
|
except GeneratorExit:
|
|
z += 1
|
|
return t.tan() # noqa: B901
|
|
finally:
|
|
z += 1
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
next(gen)
|
|
b = gen.close()
|
|
return t.sin(), b
|
|
|
|
t = torch.randn(2)
|
|
y, b = fn(t)
|
|
self.assertEqual(y, t.sin())
|
|
self.assertEqual(b, t.tan())
|
|
self.assertEqual(z, 2)
|
|
|
|
|
|
class TestGeneratorThrow(GeneratorTestsBase):
|
|
def test_throw(self):
|
|
def whoo(t):
|
|
try:
|
|
yield t.sin()
|
|
except RuntimeError:
|
|
yield t.cos()
|
|
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
a = next(gen)
|
|
b = gen.throw(RuntimeError)
|
|
return a + b
|
|
|
|
t = torch.randn(2)
|
|
y = self._compile_check(fn, (t,))
|
|
self.assertEqual(y, t.sin() + t.cos())
|
|
|
|
def test_throw_with_finally(self):
|
|
z = 0
|
|
|
|
def whoo():
|
|
nonlocal z
|
|
z = 0
|
|
try:
|
|
try:
|
|
yield 1
|
|
except ValueError:
|
|
yield 2
|
|
finally:
|
|
z += 2
|
|
except ValueError:
|
|
z += 33
|
|
yield 4
|
|
finally:
|
|
z += 1
|
|
z += 10
|
|
|
|
def f(x):
|
|
gen = whoo()
|
|
next(gen)
|
|
gen.throw(ValueError)
|
|
return x.sin()
|
|
|
|
self._compile_check(f)
|
|
self.assertEqual(z, 3)
|
|
|
|
def test_throw_without_finally(self):
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
z = 0
|
|
try:
|
|
z += 1
|
|
yield t.sin()
|
|
z += 10
|
|
except RuntimeError:
|
|
z += 100
|
|
yield t.cos()
|
|
z += 1_000
|
|
z += 10_000
|
|
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
a = next(gen)
|
|
b = gen.throw(RuntimeError)
|
|
return a + b
|
|
|
|
t = torch.randn(2)
|
|
y = self._compile_check(fn, (t,))
|
|
self.assertEqual(y, t.sin() + t.cos())
|
|
self.assertEqual(z, 101)
|
|
|
|
def test_throw_no_yield_after_throw(self):
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
z = 0
|
|
try:
|
|
z += 1
|
|
yield t.sin()
|
|
except ValueError:
|
|
z += 10
|
|
finally:
|
|
z += 100
|
|
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
a = next(gen)
|
|
try:
|
|
gen.throw(ValueError)
|
|
except StopIteration:
|
|
return a
|
|
|
|
t = torch.randn(2)
|
|
y = self._compile_check(fn, (t,))
|
|
self.assertEqual(z, 111)
|
|
self.assertEqual(y, t.sin())
|
|
|
|
def test_throw_not_catch(self):
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
z = 0
|
|
try:
|
|
z += 1
|
|
yield t.sin()
|
|
except ValueError:
|
|
z += 10
|
|
yield t.cos()
|
|
finally:
|
|
z += 100
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
a = next(gen)
|
|
b = gen.throw(RuntimeError)
|
|
return a + b
|
|
|
|
t = torch.randn(2)
|
|
with self.assertRaises(RuntimeError):
|
|
fn(t)
|
|
|
|
def test_throw_raise_difference_exc(self):
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
z = 0
|
|
try:
|
|
z += 1
|
|
yield t.sin()
|
|
except ValueError as e:
|
|
z += 10
|
|
raise RuntimeError from e
|
|
finally:
|
|
z += 100
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
a = next(gen)
|
|
b = gen.throw(ValueError)
|
|
return a + b
|
|
|
|
t = torch.randn(2)
|
|
with self.assertRaises(RuntimeError):
|
|
fn(t)
|
|
|
|
def test_throw_yield_finally(self):
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
z = 0
|
|
try:
|
|
z += 1
|
|
yield t.sin()
|
|
except RuntimeError:
|
|
z += 10
|
|
yield t.cos()
|
|
finally:
|
|
z += 100
|
|
yield t.tan() # RuntimeError: generator ignored GeneratorExit
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
a = next(gen)
|
|
b = gen.throw(RuntimeError)
|
|
return a + b
|
|
|
|
t = torch.randn(2)
|
|
with self.assertRaises(Unsupported):
|
|
fn(t)
|
|
|
|
def test_throw_try_except_finally(self):
|
|
z = 0
|
|
|
|
def whoo(t):
|
|
nonlocal z
|
|
z = 0
|
|
try:
|
|
z += 1
|
|
yield t.sin()
|
|
except ValueError:
|
|
z += 10
|
|
yield t.cos()
|
|
except RuntimeError:
|
|
z += 100
|
|
yield t.tan()
|
|
finally:
|
|
z += 1000
|
|
z += 10_000
|
|
|
|
def fn(t):
|
|
gen = whoo(t)
|
|
a = next(gen)
|
|
b = gen.throw(RuntimeError)
|
|
return a + b
|
|
|
|
t = torch.randn(2)
|
|
y = self._compile_check(fn, (t,))
|
|
self.assertEqual(y, t.sin() + t.tan())
|
|
self.assertEqual(z, 1 + 100 + 1000)
|
|
|
|
def test_exception_context_with_yield(self):
|
|
def f():
|
|
yield
|
|
|
|
def fn(t):
|
|
gen = f()
|
|
gen.send(None)
|
|
try:
|
|
gen.throw(ValueError)
|
|
except ValueError:
|
|
z = 1
|
|
except Exception as e:
|
|
raise AssertionError from e
|
|
assert z == 1
|
|
return t.sin()
|
|
|
|
self._compile_check(fn)
|
|
|
|
|
|
instantiate_parametrized_tests(GeneratorTests)
|
|
instantiate_parametrized_tests(TestGeneratorSend)
|
|
instantiate_parametrized_tests(TestGeneratorClose)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from torch._dynamo.test_case import run_tests
|
|
|
|
run_tests()
|