[dynamo] Support range_iterator as a function input (#138657)
Fixes https://github.com/pytorch/pytorch/issues/138654

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138657
Approved by: https://github.com/williamwen42, https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: e5c3d7ab77
Commit: b1acd0978e
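To illustrate what the change enables, here is a rough, hypothetical sketch (written for this summary, not code taken from the PR or from issue #138654) of a compiled function that receives an already-created range iterator as an input, the case this PR adds support for:

import torch

def fn(it):
    # `it` is a range_iterator, i.e. an object of type(iter(range(0)))
    acc = torch.zeros(1)
    for i in it:
        acc = acc + i
    return acc

opt_fn = torch.compile(fn, backend="eager")

# With this change, Dynamo wraps the iterator input instead of giving up on it.
print(fn(iter(range(5))), opt_fn(iter(range(5))))  # both should print tensor([10.])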
@@ -37,7 +37,7 @@ import torch.library
 import torch.utils._pytree as pytree
 from torch import nn
 from torch._dynamo.debug_utils import same_two_models
-from torch._dynamo.testing import CompileCounter, rand_strided, same
+from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312
 from torch._inductor.utils import fresh_inductor_cache
 from torch.nn import functional as F
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
@@ -6170,6 +6170,34 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
 
         self.assertEqual(ref, res)
 
+    @skipIfPy312  # listcomp bytecode is optimized
+    def test_listcomp(self):
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self._num = 4
+
+            @torch._dynamo.disable(recursive=False)
+            def forward(self, x):
+                values = [i * torch.cos(x) for i in range(self._num)]
+                return sum(values)
+
+        mod = Module()
+
+        def fn(x):
+            return mod(x)
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch.compile(fn, backend=cnt)
+        x = torch.randn(4)
+
+        ref = fn(x)
+        res = opt_fn(x)
+        self.assertEqual(ref, res)
+        self.assertEqual(cnt.frame_count, 1)
+        # Ensure that the listcomp is fully compiled
+        self.assertEqual(cnt.op_count, 8)
+
 
 instantiate_parametrized_tests(ReproTests)
 
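A note on the @skipIfPy312 marker: before Python 3.12, a list comprehension compiles to its own code object and runs in its own frame, so even with @torch._dynamo.disable(recursive=False) on forward, Dynamo still compiles the comprehension frame, and the iterator over range(self._num) reaches that frame as an input, which is what exercises the new range_iterator support. In Python 3.12 (PEP 709) comprehensions are inlined, so the bytecode the test relies on is different. A small standalone illustration (not part of the PR):

import dis

def f(n):
    return [i * 2 for i in range(n)]

# On Python <= 3.11 the disassembly shows a separate <listcomp> code object
# being created and called (its own frame); on Python 3.12+ the comprehension
# body is inlined into f itself.
dis.dis(f)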
@@ -20,6 +20,7 @@ from torch.testing._internal.common_dtype import (
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict,
     parametrize,
+    skipIfTorchDynamo,
     IS_WINDOWS)
 from torch.testing._internal.common_device_type import (
     OpDTypes, expectedFailureMeta, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU,
@@ -2589,7 +2590,8 @@ class TestReductions(TestCase):
         self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
         self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
 
 
+    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/pull/138657 discovers a latent bug")
     @onlyNativeDeviceTypes
     @dtypes(torch.float, torch.double)
     def test_quantile(self, device, dtype):
@@ -1508,6 +1508,7 @@ dict_keys: Type[KeysView[Any]] = type({}.keys())
 dict_values: Type[ValuesView[Any]] = type({}.values())
 odict_values: Type[ValuesView[Any]] = type(collections.OrderedDict().values())
 tuple_iterator: Type[Iterator[Any]] = type(iter(()))
+range_iterator: Type[Iterator[Any]] = type(iter(range(0)))
 tuple_iterator_len = tuple_iterator.__length_hint__  # type: ignore[attr-defined]
 object_new = object.__new__
 
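For reference, the new range_iterator alias in the hunk above binds the concrete CPython type of a range iterator, parallel to the existing tuple_iterator alias. A quick illustrative check (plain Python, not Dynamo code):

range_iterator = type(iter(range(0)))   # same expression as in the diff above

print(range_iterator)                               # <class 'range_iterator'>
print(isinstance(iter(range(3)), range_iterator))   # True
print(isinstance(iter(()), range_iterator))         # False (that is tuple_iterator)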
@@ -3,6 +3,7 @@
 import abc
 import collections
 import contextlib
+import copy
 import dataclasses
 import enum
 import functools
@@ -106,6 +107,7 @@ from ..utils import (
     istype,
     odict_values,
     proxy_args_kwargs,
+    range_iterator,
     set_example_value,
     tensor_always_has_static_shape,
     tuple_iterator,
@@ -153,6 +155,7 @@ from .iter import ItertoolsVariable
 from .lazy import LazyVariableTracker
 from .lists import (
     BaseListVariable,
+    ListIteratorVariable,
     ListVariable,
     NamedTupleVariable,
     RangeVariable,
@@ -448,6 +451,7 @@ class VariableBuilder:
                 cls.wrap_listlike,
             ),
             (tuple_iterator, cls.wrap_tuple_iterator),
+            (range_iterator, cls.wrap_range_iterator),
             ((slice, range), cls.wrap_slice_range),
             (tuple(common_constant_types), cls.wrap_literal),
             (re.Pattern, cls.wrap_regex_pattern),
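The entries above pair a Python type (or tuple of types) with the VariableBuilder method used to wrap inputs of that type; the new line routes range iterators to wrap_range_iterator. A simplified, hypothetical sketch of such a type-to-handler table (the names and matching rules here are illustrative, not Dynamo's actual lookup):

from typing import Any

_handlers: list[tuple[Any, str]] = [
    (type(iter(())), "wrap_tuple_iterator"),
    (type(iter(range(0))), "wrap_range_iterator"),  # the entry this PR adds
    ((slice, range), "wrap_slice_range"),
]

def pick_handler(value: Any) -> str:
    # Walk the table and return the name of the first matching handler.
    for typ, name in _handlers:
        if isinstance(value, typ):
            return name
    return "wrap_literal"  # placeholder default for anything else

print(pick_handler(iter(range(3))))  # wrap_range_iterator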
@@ -1312,6 +1316,12 @@ class VariableBuilder:
 
         return self.set_source_and_track_mutable(value, result)
 
+    def wrap_range_iterator(self, value: range_iterator):
+        self.install_guards(GuardBuilder.TYPE_MATCH)
+        # Get all the values from the range iterator
+        items = [ConstantVariable.create(v) for v in copy.deepcopy(value)]
+        return ListIteratorVariable(items, mutable_local=MutableLocal())
+
     def wrap_slice_range(self, value: Union[slice, range]):
         items = [
             VariableBuilder(self.tx, AttrSource(self.get_source(), k))(
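The new wrap_range_iterator method materializes the iterator's remaining values as constants and hands them to a ListIteratorVariable. Deep-copying the iterator first is what lets it enumerate those values without advancing the object the user passed in; a standalone illustration of that property of range iterators (plain Python, not Dynamo code):

import copy

it = iter(range(5))
next(it)                        # suppose the caller already consumed one element
snapshot = copy.deepcopy(it)    # independent copy at the same position
print(list(snapshot))           # [1, 2, 3, 4]
print(next(it))                 # 1 -- the original iterator is unaffected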