mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Support next(iterator, default)
(#159483)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159483 Approved by: https://github.com/mlazos ghstack dependencies: #159365, #159366, #159368
This commit is contained in:
committed by
PyTorch MergeBot
parent
e5621b4d8b
commit
b78968b4d1
@ -46,6 +46,7 @@ from .. import config, graph_break_hints, polyfills, variables
|
||||
from ..exc import (
|
||||
AttributeMutationError,
|
||||
ObservedAttributeError,
|
||||
ObservedUserStopIteration,
|
||||
raise_observed_exception,
|
||||
unimplemented_v2,
|
||||
Unsupported,
|
||||
@ -2140,9 +2141,14 @@ class BuiltinVariable(VariableTracker):
|
||||
def call_super(self, tx: "InstructionTranslator", a, b):
|
||||
return variables.SuperVariable(a, b)
|
||||
|
||||
def call_next(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def call_next(self, tx: "InstructionTranslator", *args):
|
||||
arg = args[0]
|
||||
try:
|
||||
return arg.next_variable(tx)
|
||||
except ObservedUserStopIteration:
|
||||
if len(args) == 2:
|
||||
return args[1]
|
||||
raise
|
||||
except Unsupported as ex:
|
||||
if isinstance(arg, variables.BaseListVariable):
|
||||
ex.remove_from_stats()
|
||||
|
Reference in New Issue
Block a user