mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Implement generator.send(..)
(#144422)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144422 Approved by: https://github.com/zou3519 ghstack dependencies: #141055, #144421
This commit is contained in:
committed by
PyTorch MergeBot
parent
d798831167
commit
ca9b16e070
@ -613,13 +613,51 @@ class GraphModule(torch.nn.Module):
|
||||
self.assertEqual(y, t + sum(range(6)))
|
||||
|
||||
|
||||
class TestGeneratorSend(GeneratorTestsBase):
|
||||
def test_send(self):
|
||||
def double():
|
||||
x = yield
|
||||
yield x * 2
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = double()
|
||||
next(gen)
|
||||
return gen.send(t)
|
||||
|
||||
t = torch.randn(2)
|
||||
y = fn(t)
|
||||
self.assertEqual(y, t * 2)
|
||||
|
||||
@parametrize("fullgraph", [True, False])
|
||||
def test_send_stop_iteration(self, fullgraph):
|
||||
def double():
|
||||
x = yield
|
||||
yield x * 2
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=fullgraph)
|
||||
def fn(t):
|
||||
gen = double()
|
||||
next(gen)
|
||||
a = gen.send(t)
|
||||
b = gen.send(t) # should result in StopIteration
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2)
|
||||
if fullgraph:
|
||||
with self.assertRaisesRegex(Unsupported, "Observed exception"):
|
||||
fn(t)
|
||||
else:
|
||||
with self.assertRaises(StopIteration):
|
||||
fn(t)
|
||||
|
||||
|
||||
class GeneratorCPythonTests(GeneratorTestsBase):
|
||||
# Taken from commit
|
||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
||||
# changed the tests a little bit to run them inside dynamo
|
||||
# + replaced all self.assert* calls to plain assert statements
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_send_non_none_to_new_gen(self):
|
||||
def f():
|
||||
yield 1
|
||||
@ -661,6 +699,7 @@ class GeneratorCPythonTests(GeneratorTestsBase):
|
||||
|
||||
|
||||
instantiate_parametrized_tests(GeneratorTests)
|
||||
instantiate_parametrized_tests(TestGeneratorSend)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -290,6 +290,11 @@ class ObservedNotImplementedError(ObservedException):
|
||||
pass
|
||||
|
||||
|
||||
class ObservedTypeError(ObservedException):
|
||||
# A TypeError exception to be raised from inside Dynamo tracing. This can happen on generator.send(..) method
|
||||
pass
|
||||
|
||||
|
||||
observed_exception_map = {
|
||||
StopIteration: ObservedUserStopIteration,
|
||||
LookupError: ObservedLookupError,
|
||||
@ -299,6 +304,7 @@ observed_exception_map = {
|
||||
AttributeError: ObservedAttributeError,
|
||||
RuntimeError: ObservedRuntimeError,
|
||||
NotImplementedError: ObservedNotImplementedError,
|
||||
TypeError: ObservedTypeError,
|
||||
}
|
||||
|
||||
|
||||
|
@ -490,6 +490,9 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
||||
break
|
||||
return result
|
||||
|
||||
def _is_generator_just_started(self):
|
||||
return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
@ -502,6 +505,21 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
||||
elif name == "__iter__":
|
||||
# iter(gen) returns itself
|
||||
return self
|
||||
elif name == "send":
|
||||
# Sends a value into the generator function. Returns the next value
|
||||
# yielded by the generator, or raises StopIteration if the generator
|
||||
# exits without yielding another value
|
||||
if self._is_generator_just_started() and len(args):
|
||||
# can't send non-None value to a just-started generator
|
||||
# Test: GeneratorCPythonTests.test_send_non_none_to_new_gen
|
||||
if not all(
|
||||
isinstance(arg, ConstantVariable) and arg.value is None
|
||||
for arg in args
|
||||
):
|
||||
raise_observed_exception(TypeError, tx)
|
||||
tracer = self._get_inline_tracer(tx)
|
||||
tracer.push_many(args)
|
||||
return self.next_variable(tx)
|
||||
|
||||
super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
Reference in New Issue
Block a user