mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See https://github.com/pytorch/pytorch/pull/148669#discussion_r1983462218 for more details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148676 Approved by: https://github.com/jansel
79 lines
2.5 KiB
Python
79 lines
2.5 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
import collections
|
|
import contextlib
|
|
|
|
import torch
|
|
import torch._inductor.test_case
|
|
|
|
|
|
class TestDequeReconstruct(torch._inductor.test_case.TestCase):
|
|
UNSET = object()
|
|
|
|
@contextlib.contextmanager
|
|
def set_deque_in_globals(self, value):
|
|
prev = globals().pop("deque", self.UNSET)
|
|
assert "deque" not in globals()
|
|
|
|
try:
|
|
if value is not self.UNSET:
|
|
globals()["deque"] = value
|
|
yield
|
|
finally:
|
|
if prev is self.UNSET:
|
|
globals().pop("deque", None)
|
|
assert "deque" not in globals()
|
|
else:
|
|
globals()["deque"] = prev
|
|
|
|
def test_deque_reconstruct_not_in_globals(self):
|
|
with self.set_deque_in_globals(self.UNSET):
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def func(x):
|
|
return collections.deque([x, x + 1, x + 2], maxlen=2)
|
|
|
|
x = torch.randn(3, 4)
|
|
out = func(x)
|
|
self.assertIsInstance(out, collections.deque)
|
|
self.assertEqual(out.maxlen, 2)
|
|
self.assertEqual(out, collections.deque([x + 1, x + 2], maxlen=2))
|
|
|
|
def test_deque_reconstruct_in_globals(self):
|
|
with self.set_deque_in_globals(collections.deque):
|
|
# This does not emit a NameError
|
|
dummy = deque([0, 1, 2], maxlen=2) # noqa: F821
|
|
self.assertIsInstance(dummy, collections.deque)
|
|
self.assertEqual(list(dummy), [1, 2])
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def func(x):
|
|
return collections.deque([x, x + 1, x + 2], maxlen=2)
|
|
|
|
x = torch.randn(3, 4)
|
|
out = func(x)
|
|
self.assertIsInstance(out, collections.deque)
|
|
self.assertEqual(out.maxlen, 2)
|
|
self.assertEqual(out, collections.deque([x + 1, x + 2], maxlen=2))
|
|
|
|
def test_deque_reconstruct_shallows_globals(self):
|
|
with self.set_deque_in_globals(None):
|
|
# This does not emit a NameError
|
|
self.assertIsNone(deque) # noqa: F821
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def func(x):
|
|
return collections.deque([x, x + 1, x + 2], maxlen=2)
|
|
|
|
x = torch.randn(3, 4)
|
|
out = func(x)
|
|
self.assertIsInstance(out, collections.deque)
|
|
self.assertEqual(out.maxlen, 2)
|
|
self.assertEqual(out, collections.deque([x + 1, x + 2], maxlen=2))
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported lazily so the test runner is only loaded when this file is
    # executed directly as a script.
    from torch._dynamo.test_case import run_tests

    run_tests()