[dynamo] Support range_iterator as a function input (#138657)

Fixes https://github.com/pytorch/pytorch/issues/138654

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138657
Approved by: https://github.com/williamwen42, https://github.com/jansel
This commit is contained in:
Animesh Jain
2024-10-23 13:35:44 -07:00
committed by PyTorch MergeBot
parent e5c3d7ab77
commit b1acd0978e
4 changed files with 42 additions and 2 deletions

View File

@ -37,7 +37,7 @@ import torch.library
import torch.utils._pytree as pytree
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import CompileCounter, rand_strided, same
from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312
from torch._inductor.utils import fresh_inductor_cache
from torch.nn import functional as F
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
@ -6170,6 +6170,34 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
self.assertEqual(ref, res)
@skipIfPy312 # listcomp bytecode is optimized
def test_listcomp(self):
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self._num = 4
@torch._dynamo.disable(recursive=False)
def forward(self, x):
values = [i * torch.cos(x) for i in range(self._num)]
return sum(values)
mod = Module()
def fn(x):
return mod(x)
cnt = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnt)
x = torch.randn(4)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnt.frame_count, 1)
# Ensure that the listcomp is fully compiled
self.assertEqual(cnt.op_count, 8)
instantiate_parametrized_tests(ReproTests)

View File

@ -20,6 +20,7 @@ from torch.testing._internal.common_dtype import (
from torch.testing._internal.common_utils import (
TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict,
parametrize,
skipIfTorchDynamo,
IS_WINDOWS)
from torch.testing._internal.common_device_type import (
OpDTypes, expectedFailureMeta, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU,
@ -2589,7 +2590,7 @@ class TestReductions(TestCase):
self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/pull/138657 discovers a latent bug")
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.double)
def test_quantile(self, device, dtype):

View File

@ -1508,6 +1508,7 @@ dict_keys: Type[KeysView[Any]] = type({}.keys())
dict_values: Type[ValuesView[Any]] = type({}.values())
odict_values: Type[ValuesView[Any]] = type(collections.OrderedDict().values())
tuple_iterator: Type[Iterator[Any]] = type(iter(()))
range_iterator: Type[Iterator[Any]] = type(iter(range(0)))
tuple_iterator_len = tuple_iterator.__length_hint__ # type: ignore[attr-defined]
object_new = object.__new__

View File

@ -3,6 +3,7 @@
import abc
import collections
import contextlib
import copy
import dataclasses
import enum
import functools
@ -106,6 +107,7 @@ from ..utils import (
istype,
odict_values,
proxy_args_kwargs,
range_iterator,
set_example_value,
tensor_always_has_static_shape,
tuple_iterator,
@ -153,6 +155,7 @@ from .iter import ItertoolsVariable
from .lazy import LazyVariableTracker
from .lists import (
BaseListVariable,
ListIteratorVariable,
ListVariable,
NamedTupleVariable,
RangeVariable,
@ -448,6 +451,7 @@ class VariableBuilder:
cls.wrap_listlike,
),
(tuple_iterator, cls.wrap_tuple_iterator),
(range_iterator, cls.wrap_range_iterator),
((slice, range), cls.wrap_slice_range),
(tuple(common_constant_types), cls.wrap_literal),
(re.Pattern, cls.wrap_regex_pattern),
@ -1312,6 +1316,12 @@ class VariableBuilder:
return self.set_source_and_track_mutable(value, result)
def wrap_range_iterator(self, value: range_iterator):
self.install_guards(GuardBuilder.TYPE_MATCH)
# Get all the values from the range iterator
items = [ConstantVariable.create(v) for v in copy.deepcopy(value)]
return ListIteratorVariable(items, mutable_local=MutableLocal())
def wrap_slice_range(self, value: Union[slice, range]):
items = [
VariableBuilder(self.tx, AttrSource(self.get_source(), k))(