[dynamo] add itertools repeat/count bytecode reconstruction (#131716)

Also fix bugs in the count iterator variable implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131716
Approved by: https://github.com/anijain2305
ghstack dependencies: #131413
This commit is contained in:
William Wen
2024-07-25 20:50:08 -07:00
committed by PyTorch MergeBot
parent 40cc5c0697
commit 35b4de32fa
2 changed files with 45 additions and 4 deletions

View File

@ -12,7 +12,7 @@ if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .. import polyfill, variables
from ..bytecode_transformation import create_instruction
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import (
handle_observed_user_stop_iteration,
ObservedUserStopIteration,
@ -240,6 +240,18 @@ class RepeatIteratorVariable(IteratorVariable):
def next_variable(self, tx):
return self.item
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.extend_output(
[
codegen.create_load_python_module(itertools),
codegen.create_load_attr("repeat"),
]
)
)
codegen(self.item)
codegen.extend_output(create_call_function(1, False))
class CountIteratorVariable(IteratorVariable):
def __init__(self, item: int = 0, step: int = 1, **kwargs):
@ -253,10 +265,23 @@ class CountIteratorVariable(IteratorVariable):
def next_variable(self, tx):
assert self.mutable_local
old_item = self.item
tx.output.side_effects.mutation(self)
next_item = self.item.call_method(tx, "__add__", [self.step], {})
self.item = next_item
return self.item
self.item = self.item.call_method(tx, "__add__", [self.step], {})
return old_item
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.extend_output(
[
codegen.create_load_python_module(itertools),
codegen.create_load_attr("count"),
]
)
)
codegen(self.item)
codegen(self.step)
codegen.extend_output(create_call_function(2, False))
class CycleIteratorVariable(IteratorVariable):