Implement generator.send(..) (#144422)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144422
Approved by: https://github.com/zou3519
ghstack dependencies: #141055, #144421
This commit is contained in:
Guilherme Leobas
2025-02-07 14:55:20 -03:00
committed by PyTorch MergeBot
parent d798831167
commit ca9b16e070
3 changed files with 64 additions and 1 deletions

View File

@ -613,13 +613,51 @@ class GraphModule(torch.nn.Module):
self.assertEqual(y, t + sum(range(6)))
class TestGeneratorSend(GeneratorTestsBase):
def test_send(self):
def double():
x = yield
yield x * 2
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = double()
next(gen)
return gen.send(t)
t = torch.randn(2)
y = fn(t)
self.assertEqual(y, t * 2)
@parametrize("fullgraph", [True, False])
def test_send_stop_iteration(self, fullgraph):
def double():
x = yield
yield x * 2
@torch.compile(backend="eager", fullgraph=fullgraph)
def fn(t):
gen = double()
next(gen)
a = gen.send(t)
b = gen.send(t) # should result in StopIteration
return a + b
t = torch.randn(2)
if fullgraph:
with self.assertRaisesRegex(Unsupported, "Observed exception"):
fn(t)
else:
with self.assertRaises(StopIteration):
fn(t)
class GeneratorCPythonTests(GeneratorTestsBase):
# Taken from commit
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
# changed the tests a little bit to run them inside dynamo
# + replaced all self.assert* calls to plain assert statements
@unittest.expectedFailure
def test_send_non_none_to_new_gen(self):
def f():
yield 1
@ -661,6 +699,7 @@ class GeneratorCPythonTests(GeneratorTestsBase):
instantiate_parametrized_tests(GeneratorTests)
instantiate_parametrized_tests(TestGeneratorSend)
if __name__ == "__main__":

View File

@ -290,6 +290,11 @@ class ObservedNotImplementedError(ObservedException):
pass
class ObservedTypeError(ObservedException):
# A TypeError exception to be raised from inside Dynamo tracing. This can happen on generator.send(..) method
pass
observed_exception_map = {
StopIteration: ObservedUserStopIteration,
LookupError: ObservedLookupError,
@ -299,6 +304,7 @@ observed_exception_map = {
AttributeError: ObservedAttributeError,
RuntimeError: ObservedRuntimeError,
NotImplementedError: ObservedNotImplementedError,
TypeError: ObservedTypeError,
}

View File

@ -490,6 +490,9 @@ class LocalGeneratorObjectVariable(VariableTracker):
break
return result
def _is_generator_just_started(self):
return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0
def call_method(
self,
tx: "InstructionTranslator",
@ -502,6 +505,21 @@ class LocalGeneratorObjectVariable(VariableTracker):
elif name == "__iter__":
# iter(gen) returns itself
return self
elif name == "send":
# Sends a value into the generator function. Returns the next value
# yielded by the generator, or raises StopIteration if the generator
# exits without yielding another value
if self._is_generator_just_started() and len(args):
# can't send non-None value to a just-started generator
# Test: GeneratorCPythonTests.test_send_non_none_to_new_gen
if not all(
isinstance(arg, ConstantVariable) and arg.value is None
for arg in args
):
raise_observed_exception(TypeError, tx)
tracer = self._get_inline_tracer(tx)
tracer.push_many(args)
return self.next_variable(tx)
super().call_method(tx, name, args, kwargs)