mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
redirect iter(range)
to range.__iter__()
(#161803)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161803 Approved by: https://github.com/anijain2305 ghstack dependencies: #161801, #161802
This commit is contained in:
committed by
PyTorch MergeBot
parent
485a7bd82e
commit
c8255c67cd
@ -3529,7 +3529,6 @@ class GraphModule(torch.nn.Module):
|
||||
return a + b
|
||||
return a - b
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_test
|
||||
def test_range_iterator_2(a, b):
|
||||
# should pass once we stop having three different paths on call_iter
|
||||
|
@ -1820,6 +1820,8 @@ class BuiltinVariable(VariableTracker):
|
||||
def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs):
|
||||
if isinstance(obj, variables.IteratorVariable):
|
||||
ret = obj
|
||||
elif isinstance(obj, variables.RangeVariable):
|
||||
ret = obj.call_method(tx, "__iter__", [], {})
|
||||
else:
|
||||
# Handle the case where we are iterating over a tuple, list or iterator
|
||||
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
|
||||
|
Reference in New Issue
Block a user