mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] allow global import from collections import deque
in user code (#148676)
See https://github.com/pytorch/pytorch/pull/148669#discussion_r1983462218 for more details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148676 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
59f14d19ae
commit
098494e9cb
78
test/dynamo/test_deque_reconstruct.py
Normal file
78
test/dynamo/test_deque_reconstruct.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
import torch._inductor.test_case
|
||||
|
||||
|
||||
class TestDequeReconstruct(torch._inductor.test_case.TestCase):
|
||||
UNSET = object()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_deque_in_globals(self, value):
|
||||
prev = globals().pop("deque", self.UNSET)
|
||||
assert "deque" not in globals()
|
||||
|
||||
try:
|
||||
if value is not self.UNSET:
|
||||
globals()["deque"] = value
|
||||
yield
|
||||
finally:
|
||||
if prev is self.UNSET:
|
||||
globals().pop("deque", None)
|
||||
assert "deque" not in globals()
|
||||
else:
|
||||
globals()["deque"] = prev
|
||||
|
||||
def test_deque_reconstruct_not_in_globals(self):
|
||||
with self.set_deque_in_globals(self.UNSET):
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def func(x):
|
||||
return collections.deque([x, x + 1, x + 2], maxlen=2)
|
||||
|
||||
x = torch.randn(3, 4)
|
||||
out = func(x)
|
||||
self.assertIsInstance(out, collections.deque)
|
||||
self.assertEqual(out.maxlen, 2)
|
||||
self.assertEqual(out, collections.deque([x + 1, x + 2], maxlen=2))
|
||||
|
||||
def test_deque_reconstruct_in_globals(self):
|
||||
with self.set_deque_in_globals(collections.deque):
|
||||
# This does not emit a NameError
|
||||
dummy = deque([0, 1, 2], maxlen=2) # noqa: F821
|
||||
self.assertIsInstance(dummy, collections.deque)
|
||||
self.assertEqual(list(dummy), [1, 2])
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def func(x):
|
||||
return collections.deque([x, x + 1, x + 2], maxlen=2)
|
||||
|
||||
x = torch.randn(3, 4)
|
||||
out = func(x)
|
||||
self.assertIsInstance(out, collections.deque)
|
||||
self.assertEqual(out.maxlen, 2)
|
||||
self.assertEqual(out, collections.deque([x + 1, x + 2], maxlen=2))
|
||||
|
||||
def test_deque_reconstruct_shallows_globals(self):
|
||||
with self.set_deque_in_globals(None):
|
||||
# This does not emit a NameError
|
||||
self.assertIsNone(deque) # noqa: F821
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def func(x):
|
||||
return collections.deque([x, x + 1, x + 2], maxlen=2)
|
||||
|
||||
x = torch.randn(3, 4)
|
||||
out = func(x)
|
||||
self.assertIsInstance(out, collections.deque)
|
||||
self.assertEqual(out.maxlen, 2)
|
||||
self.assertEqual(out, collections.deque([x + 1, x + 2], maxlen=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["module: pytree"]
|
||||
|
||||
import collections
|
||||
import enum
|
||||
import inspect
|
||||
import os
|
||||
@ -9,7 +8,7 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
from collections import defaultdict, namedtuple, OrderedDict, UserDict
|
||||
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
|
||||
from dataclasses import dataclass
|
||||
from enum import auto
|
||||
from typing import Any, NamedTuple
|
||||
@ -405,7 +404,7 @@ class TestGenericPytree(TestCase):
|
||||
(
|
||||
py_pytree,
|
||||
lambda deq: py_pytree.TreeSpec(
|
||||
collections.deque,
|
||||
deque,
|
||||
deq.maxlen,
|
||||
[py_pytree.LeafSpec() for _ in deq],
|
||||
),
|
||||
@ -416,7 +415,7 @@ class TestGenericPytree(TestCase):
|
||||
(
|
||||
cxx_pytree,
|
||||
lambda deq: cxx_pytree.tree_structure(
|
||||
collections.deque(deq, maxlen=deq.maxlen)
|
||||
deque(deq, maxlen=deq.maxlen)
|
||||
),
|
||||
),
|
||||
name="cxx",
|
||||
@ -434,11 +433,11 @@ class TestGenericPytree(TestCase):
|
||||
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, deq)
|
||||
self.assertEqual(unflattened.maxlen, deq.maxlen)
|
||||
self.assertIsInstance(unflattened, collections.deque)
|
||||
self.assertIsInstance(unflattened, deque)
|
||||
|
||||
run_test(collections.deque([]))
|
||||
run_test(collections.deque([1.0, 2]))
|
||||
run_test(collections.deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
|
||||
run_test(deque([]))
|
||||
run_test(deque([1.0, 2]))
|
||||
run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
|
||||
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
@ -1470,7 +1469,7 @@ if "optree" in sys.modules:
|
||||
namedtuple: namedtuple("ANamedTuple", ["x", "y"])(1, 2),
|
||||
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
|
||||
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
|
||||
collections.deque: collections.deque([1, 2, 3]),
|
||||
deque: deque([1, 2, 3]),
|
||||
torch.Size: torch.Size([1, 2, 3]),
|
||||
immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
|
||||
immutable_list: immutable_list([1, 2, 3]),
|
||||
|
@ -563,7 +563,6 @@ class DequeVariable(CommonListMethodsVariable):
|
||||
)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
assert "deque" not in codegen.tx.f_globals
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.append_output(
|
||||
codegen.create_load_python_module(collections.deque)
|
||||
|
Reference in New Issue
Block a user