[dynamo] allow global import from collections import deque in user code (#148676)

See https://github.com/pytorch/pytorch/pull/148669#discussion_r1983462218 for more details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148676
Approved by: https://github.com/jansel
This commit is contained in:
Xuehai Pan
2025-03-09 13:28:13 +08:00
committed by PyTorch MergeBot
parent 59f14d19ae
commit 098494e9cb
3 changed files with 86 additions and 10 deletions

View File

@ -1,6 +1,5 @@
# Owner(s): ["module: pytree"]
import collections
import enum
import inspect
import os
@ -9,7 +8,7 @@ import subprocess
import sys
import time
import unittest
from collections import defaultdict, namedtuple, OrderedDict, UserDict
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
from dataclasses import dataclass
from enum import auto
from typing import Any, NamedTuple
@ -405,7 +404,7 @@ class TestGenericPytree(TestCase):
(
py_pytree,
lambda deq: py_pytree.TreeSpec(
collections.deque,
deque,
deq.maxlen,
[py_pytree.LeafSpec() for _ in deq],
),
@ -416,7 +415,7 @@ class TestGenericPytree(TestCase):
(
cxx_pytree,
lambda deq: cxx_pytree.tree_structure(
collections.deque(deq, maxlen=deq.maxlen)
deque(deq, maxlen=deq.maxlen)
),
),
name="cxx",
@ -434,11 +433,11 @@ class TestGenericPytree(TestCase):
unflattened = pytree_impl.tree_unflatten(values, treespec)
self.assertEqual(unflattened, deq)
self.assertEqual(unflattened.maxlen, deq.maxlen)
self.assertIsInstance(unflattened, collections.deque)
self.assertIsInstance(unflattened, deque)
run_test(collections.deque([]))
run_test(collections.deque([1.0, 2]))
run_test(collections.deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
run_test(deque([]))
run_test(deque([1.0, 2]))
run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
@parametrize(
"pytree_impl",
@ -1470,7 +1469,7 @@ if "optree" in sys.modules:
namedtuple: namedtuple("ANamedTuple", ["x", "y"])(1, 2),
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
collections.deque: collections.deque([1, 2, 3]),
deque: deque([1, 2, 3]),
torch.Size: torch.Size([1, 2, 3]),
immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
immutable_list: immutable_list([1, 2, 3]),