mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add support for GET_YIELD_FROM_ITER, YIELD_FROM, SEND (#106986)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106986 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
4f3284e3ed
commit
02c2b750c5
@ -6632,6 +6632,67 @@ def ___make_guard_fn():
|
||||
self.assertEqual(counter.frame_count, 1)
|
||||
self.assertTrue(isinstance(compiled, torch.Tensor))
|
||||
|
||||
def test_yield_from(self):
|
||||
def yield_from_fn(t_list, k):
|
||||
def yield_from_gen(l):
|
||||
l2 = [t * k for t in l]
|
||||
yield from l2
|
||||
|
||||
return [t * k for t in yield_from_gen(t_list)]
|
||||
|
||||
t_list = [torch.randn([2, 3])] * 3
|
||||
multiplier = torch.tensor([10])
|
||||
eager = yield_from_fn(t_list, 2)
|
||||
counter = CompileCounter()
|
||||
compiled = torch._dynamo.optimize(counter)(yield_from_fn)(t_list, 2)
|
||||
self.assertEqual(eager, compiled)
|
||||
self.assertEqual(counter.frame_count, 1)
|
||||
|
||||
def test_yield_gen_and_from(self):
|
||||
def populate_and_multiply_sequence(n, multiplier):
|
||||
# Inline generator
|
||||
def tensor_generator():
|
||||
for i in range(n):
|
||||
yield torch.tensor([i])
|
||||
|
||||
# Use 'yield from' to iterate over tensors and multiply
|
||||
t_list = [tensor * multiplier for tensor in tensor_generator()]
|
||||
|
||||
def yield_from_gen():
|
||||
yield from t_list
|
||||
|
||||
return [t for t in yield_from_gen()]
|
||||
|
||||
multiplier = torch.tensor([10])
|
||||
eager = populate_and_multiply_sequence(5, multiplier)
|
||||
counter = CompileCounter()
|
||||
compiled = torch._dynamo.optimize(counter)(populate_and_multiply_sequence)(
|
||||
5, multiplier
|
||||
)
|
||||
self.assertEqual(eager, compiled)
|
||||
self.assertEqual(counter.frame_count, 1)
|
||||
|
||||
def test_yield_send_to_subgenerator_graph_break(self):
|
||||
def subgenerator(tensor):
|
||||
multiplier = yield
|
||||
yield tensor * multiplier
|
||||
|
||||
def main_generator(t_list):
|
||||
for tensor in t_list:
|
||||
subgen = subgenerator(tensor)
|
||||
next(subgen)
|
||||
yield from subgen.send(torch.tensor([10]))
|
||||
|
||||
t_list = [torch.tensor([i]) for i in range(5)]
|
||||
eager = list(main_generator(t_list))
|
||||
|
||||
counter = CompileCounter()
|
||||
compiled_fn = torch._dynamo.optimize(counter)(main_generator)
|
||||
compiled = list(compiled_fn(t_list))
|
||||
|
||||
self.assertEqual(eager, compiled)
|
||||
self.assertEqual(counter.frame_count, 0)
|
||||
|
||||
|
||||
class TestTracer(JitTestCase):
|
||||
def test_jit_save(self):
|
||||
|
||||
@ -2444,3 +2444,50 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
|
||||
self.generated_items.append(self.pop())
|
||||
# TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE
|
||||
self.push(ConstantVariable(None))
|
||||
|
||||
def GET_YIELD_FROM_ITER(self, inst):
|
||||
tos = self.stack[-1]
|
||||
if not isinstance(tos, ListIteratorVariable):
|
||||
self.pop()
|
||||
res = BuiltinVariable(iter).call_function(self, [tos], {})
|
||||
self.push(res)
|
||||
return self.YIELD_FROM(inst)
|
||||
|
||||
def YIELD_FROM(self, inst):
|
||||
while True:
|
||||
tos = self.stack[-1]
|
||||
if isinstance(tos, ConstantVariable) and tos.value is None:
|
||||
self.pop()
|
||||
return
|
||||
if isinstance(tos, ListIteratorVariable):
|
||||
self.output.guards.update(tos.guards)
|
||||
try:
|
||||
val, next_iter = tos.next_variables()
|
||||
self.replace_all(tos, next_iter)
|
||||
self.push(val)
|
||||
# TODO(voz): Unclear if we need the push None in YIELD_VALUE?
|
||||
self.YIELD_VALUE(inst)
|
||||
self.pop()
|
||||
self.push(next_iter)
|
||||
except StopIteration:
|
||||
return
|
||||
else:
|
||||
unimplemented(f"YIELD_FROM {typestr(tos)}")
|
||||
|
||||
def SEND(self, inst):
|
||||
assert len(self.stack) >= 2
|
||||
val = self.pop()
|
||||
tos = self.stack[-1]
|
||||
if isinstance(tos, ListIteratorVariable):
|
||||
if isinstance(val, ConstantVariable) and val.value is None:
|
||||
self.push(val)
|
||||
self.instruction_pointer = self.indexof[inst.target]
|
||||
else:
|
||||
# invoke send
|
||||
# Unreachable code - if you hit this, you are implementing generator support and have
|
||||
# lifted the `unimplemented("generator")` in frame conversion. This codepath handles
|
||||
# subgenerator and lines up with this line in Python 3.11
|
||||
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597
|
||||
unimplemented("Unreachable sub-generator code")
|
||||
else:
|
||||
unimplemented(f"SEND {typestr(tos)}")
|
||||
|
||||
Reference in New Issue
Block a user