mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] add itertools repeat/count bytecode reconstruction (#131716)
Also fix bugs in the count iterator variable implementation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/131716 Approved by: https://github.com/anijain2305 ghstack dependencies: #131413
This commit is contained in:
committed by
PyTorch MergeBot
parent
40cc5c0697
commit
35b4de32fa
@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from .. import polyfill, variables
|
||||
from ..bytecode_transformation import create_instruction
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
from ..exc import (
|
||||
handle_observed_user_stop_iteration,
|
||||
ObservedUserStopIteration,
|
||||
@ -240,6 +240,18 @@ class RepeatIteratorVariable(IteratorVariable):
|
||||
def next_variable(self, tx):
|
||||
return self.item
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
codegen.create_load_python_module(itertools),
|
||||
codegen.create_load_attr("repeat"),
|
||||
]
|
||||
)
|
||||
)
|
||||
codegen(self.item)
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
|
||||
class CountIteratorVariable(IteratorVariable):
|
||||
def __init__(self, item: int = 0, step: int = 1, **kwargs):
|
||||
@ -253,10 +265,23 @@ class CountIteratorVariable(IteratorVariable):
|
||||
|
||||
def next_variable(self, tx):
|
||||
assert self.mutable_local
|
||||
old_item = self.item
|
||||
tx.output.side_effects.mutation(self)
|
||||
next_item = self.item.call_method(tx, "__add__", [self.step], {})
|
||||
self.item = next_item
|
||||
return self.item
|
||||
self.item = self.item.call_method(tx, "__add__", [self.step], {})
|
||||
return old_item
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
codegen.create_load_python_module(itertools),
|
||||
codegen.create_load_attr("count"),
|
||||
]
|
||||
)
|
||||
)
|
||||
codegen(self.item)
|
||||
codegen(self.step)
|
||||
codegen.extend_output(create_call_function(2, False))
|
||||
|
||||
|
||||
class CycleIteratorVariable(IteratorVariable):
|
||||
|
||||
Reference in New Issue
Block a user