[dynamo] allow global import from collections import deque in user code (#148676)

See https://github.com/pytorch/pytorch/pull/148669#discussion_r1983462218 for more details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148676
Approved by: https://github.com/jansel
This commit is contained in:
Xuehai Pan
2025-03-09 13:28:13 +08:00
committed by PyTorch MergeBot
parent 59f14d19ae
commit 098494e9cb
3 changed files with 86 additions and 10 deletions

View File

@ -0,0 +1,78 @@
# Owner(s): ["module: dynamo"]
import collections
import contextlib
import torch
import torch._inductor.test_case
class TestDequeReconstruct(torch._inductor.test_case.TestCase):
UNSET = object()
@contextlib.contextmanager
def set_deque_in_globals(self, value):
prev = globals().pop("deque", self.UNSET)
assert "deque" not in globals()
try:
if value is not self.UNSET:
globals()["deque"] = value
yield
finally:
if prev is self.UNSET:
globals().pop("deque", None)
assert "deque" not in globals()
else:
globals()["deque"] = prev
def test_deque_reconstruct_not_in_globals(self):
with self.set_deque_in_globals(self.UNSET):
@torch.compile(backend="eager", fullgraph=True)
def func(x):
return collections.deque([x, x + 1, x + 2], maxlen=2)
x = torch.randn(3, 4)
out = func(x)
self.assertIsInstance(out, collections.deque)
self.assertEqual(out.maxlen, 2)
self.assertEqual(out, collections.deque([x + 1, x + 2], maxlen=2))
def test_deque_reconstruct_in_globals(self):
with self.set_deque_in_globals(collections.deque):
# This does not emit a NameError
dummy = deque([0, 1, 2], maxlen=2) # noqa: F821
self.assertIsInstance(dummy, collections.deque)
self.assertEqual(list(dummy), [1, 2])
@torch.compile(backend="eager", fullgraph=True)
def func(x):
return collections.deque([x, x + 1, x + 2], maxlen=2)
x = torch.randn(3, 4)
out = func(x)
self.assertIsInstance(out, collections.deque)
self.assertEqual(out.maxlen, 2)
self.assertEqual(out, collections.deque([x + 1, x + 2], maxlen=2))
def test_deque_reconstruct_shallows_globals(self):
with self.set_deque_in_globals(None):
# This does not emit a NameError
self.assertIsNone(deque) # noqa: F821
@torch.compile(backend="eager", fullgraph=True)
def func(x):
return collections.deque([x, x + 1, x + 2], maxlen=2)
x = torch.randn(3, 4)
out = func(x)
self.assertIsInstance(out, collections.deque)
self.assertEqual(out.maxlen, 2)
self.assertEqual(out, collections.deque([x + 1, x + 2], maxlen=2))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -1,6 +1,5 @@
# Owner(s): ["module: pytree"]
import collections
import enum
import inspect
import os
@ -9,7 +8,7 @@ import subprocess
import sys
import time
import unittest
from collections import defaultdict, namedtuple, OrderedDict, UserDict
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
from dataclasses import dataclass
from enum import auto
from typing import Any, NamedTuple
@ -405,7 +404,7 @@ class TestGenericPytree(TestCase):
(
py_pytree,
lambda deq: py_pytree.TreeSpec(
collections.deque,
deque,
deq.maxlen,
[py_pytree.LeafSpec() for _ in deq],
),
@ -416,7 +415,7 @@ class TestGenericPytree(TestCase):
(
cxx_pytree,
lambda deq: cxx_pytree.tree_structure(
collections.deque(deq, maxlen=deq.maxlen)
deque(deq, maxlen=deq.maxlen)
),
),
name="cxx",
@ -434,11 +433,11 @@ class TestGenericPytree(TestCase):
unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, deq)
self.assertEqual(unflattened.maxlen, deq.maxlen)
self.assertIsInstance(unflattened, collections.deque)
self.assertIsInstance(unflattened, deque)
run_test(collections.deque([]))
run_test(collections.deque([1.0, 2]))
run_test(collections.deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
run_test(deque([]))
run_test(deque([1.0, 2]))
run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
@parametrize(
"pytree_impl",
@ -1470,7 +1469,7 @@ if "optree" in sys.modules:
namedtuple: namedtuple("ANamedTuple", ["x", "y"])(1, 2),
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
collections.deque: collections.deque([1, 2, 3]),
deque: deque([1, 2, 3]),
torch.Size: torch.Size([1, 2, 3]),
immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
immutable_list: immutable_list([1, 2, 3]),

View File

@ -563,7 +563,6 @@ class DequeVariable(CommonListMethodsVariable):
)
def reconstruct(self, codegen: "PyCodegen") -> None:
assert "deque" not in codegen.tx.f_globals
codegen.add_push_null(
lambda: codegen.append_output(
codegen.create_load_python_module(collections.deque)