Add support for GET_YIELD_FROM_ITER, YIELD_FROM, SEND (#106986)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106986
Approved by: https://github.com/jansel
This commit is contained in:
Michael Voznesensky
2023-08-19 17:29:02 +00:00
committed by PyTorch MergeBot
parent 4f3284e3ed
commit 02c2b750c5
2 changed files with 108 additions and 0 deletions

View File

@ -6632,6 +6632,67 @@ def ___make_guard_fn():
self.assertEqual(counter.frame_count, 1)
self.assertTrue(isinstance(compiled, torch.Tensor))
def test_yield_from(self):
def yield_from_fn(t_list, k):
def yield_from_gen(l):
l2 = [t * k for t in l]
yield from l2
return [t * k for t in yield_from_gen(t_list)]
t_list = [torch.randn([2, 3])] * 3
multiplier = torch.tensor([10])
eager = yield_from_fn(t_list, 2)
counter = CompileCounter()
compiled = torch._dynamo.optimize(counter)(yield_from_fn)(t_list, 2)
self.assertEqual(eager, compiled)
self.assertEqual(counter.frame_count, 1)
def test_yield_gen_and_from(self):
def populate_and_multiply_sequence(n, multiplier):
# Inline generator
def tensor_generator():
for i in range(n):
yield torch.tensor([i])
# Use 'yield from' to iterate over tensors and multiply
t_list = [tensor * multiplier for tensor in tensor_generator()]
def yield_from_gen():
yield from t_list
return [t for t in yield_from_gen()]
multiplier = torch.tensor([10])
eager = populate_and_multiply_sequence(5, multiplier)
counter = CompileCounter()
compiled = torch._dynamo.optimize(counter)(populate_and_multiply_sequence)(
5, multiplier
)
self.assertEqual(eager, compiled)
self.assertEqual(counter.frame_count, 1)
def test_yield_send_to_subgenerator_graph_break(self):
def subgenerator(tensor):
multiplier = yield
yield tensor * multiplier
def main_generator(t_list):
for tensor in t_list:
subgen = subgenerator(tensor)
next(subgen)
yield from subgen.send(torch.tensor([10]))
t_list = [torch.tensor([i]) for i in range(5)]
eager = list(main_generator(t_list))
counter = CompileCounter()
compiled_fn = torch._dynamo.optimize(counter)(main_generator)
compiled = list(compiled_fn(t_list))
self.assertEqual(eager, compiled)
self.assertEqual(counter.frame_count, 0)
class TestTracer(JitTestCase):
def test_jit_save(self):

View File

@ -2444,3 +2444,50 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
self.generated_items.append(self.pop())
# TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE
self.push(ConstantVariable(None))
def GET_YIELD_FROM_ITER(self, inst):
tos = self.stack[-1]
if not isinstance(tos, ListIteratorVariable):
self.pop()
res = BuiltinVariable(iter).call_function(self, [tos], {})
self.push(res)
return self.YIELD_FROM(inst)
def YIELD_FROM(self, inst):
while True:
tos = self.stack[-1]
if isinstance(tos, ConstantVariable) and tos.value is None:
self.pop()
return
if isinstance(tos, ListIteratorVariable):
self.output.guards.update(tos.guards)
try:
val, next_iter = tos.next_variables()
self.replace_all(tos, next_iter)
self.push(val)
# TODO(voz): Unclear if we need the push None in YIELD_VALUE?
self.YIELD_VALUE(inst)
self.pop()
self.push(next_iter)
except StopIteration:
return
else:
unimplemented(f"YIELD_FROM {typestr(tos)}")
def SEND(self, inst):
assert len(self.stack) >= 2
val = self.pop()
tos = self.stack[-1]
if isinstance(tos, ListIteratorVariable):
if isinstance(val, ConstantVariable) and val.value is None:
self.push(val)
self.instruction_pointer = self.indexof[inst.target]
else:
# invoke send
# Unreachable code - if you hit this, you are implementing generator support and have
# lifted the `unimplemented("generator")` in frame conversion. This codepath handles
# subgenerator and lines up with this line in Python 3.11
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597
unimplemented("Unreachable sub-generator code")
else:
unimplemented(f"SEND {typestr(tos)}")