mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] allow global import from collections import deque
in user code (#148676)
See https://github.com/pytorch/pytorch/pull/148669#discussion_r1983462218 for more details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148676 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
59f14d19ae
commit
098494e9cb
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["module: pytree"]
|
||||
|
||||
import collections
|
||||
import enum
|
||||
import inspect
|
||||
import os
|
||||
@ -9,7 +8,7 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
from collections import defaultdict, namedtuple, OrderedDict, UserDict
|
||||
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
|
||||
from dataclasses import dataclass
|
||||
from enum import auto
|
||||
from typing import Any, NamedTuple
|
||||
@ -405,7 +404,7 @@ class TestGenericPytree(TestCase):
|
||||
(
|
||||
py_pytree,
|
||||
lambda deq: py_pytree.TreeSpec(
|
||||
collections.deque,
|
||||
deque,
|
||||
deq.maxlen,
|
||||
[py_pytree.LeafSpec() for _ in deq],
|
||||
),
|
||||
@ -416,7 +415,7 @@ class TestGenericPytree(TestCase):
|
||||
(
|
||||
cxx_pytree,
|
||||
lambda deq: cxx_pytree.tree_structure(
|
||||
collections.deque(deq, maxlen=deq.maxlen)
|
||||
deque(deq, maxlen=deq.maxlen)
|
||||
),
|
||||
),
|
||||
name="cxx",
|
||||
@ -434,11 +433,11 @@ class TestGenericPytree(TestCase):
|
||||
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, deq)
|
||||
self.assertEqual(unflattened.maxlen, deq.maxlen)
|
||||
self.assertIsInstance(unflattened, collections.deque)
|
||||
self.assertIsInstance(unflattened, deque)
|
||||
|
||||
run_test(collections.deque([]))
|
||||
run_test(collections.deque([1.0, 2]))
|
||||
run_test(collections.deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
|
||||
run_test(deque([]))
|
||||
run_test(deque([1.0, 2]))
|
||||
run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
|
||||
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
@ -1470,7 +1469,7 @@ if "optree" in sys.modules:
|
||||
namedtuple: namedtuple("ANamedTuple", ["x", "y"])(1, 2),
|
||||
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
|
||||
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
|
||||
collections.deque: collections.deque([1, 2, 3]),
|
||||
deque: deque([1, 2, 3]),
|
||||
torch.Size: torch.Size([1, 2, 3]),
|
||||
immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
|
||||
immutable_list: immutable_list([1, 2, 3]),
|
||||
|
Reference in New Issue
Block a user