diff --git a/test/test_pytree.py b/test/test_pytree.py index 1a7380ea5f78..27923b1e568d 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -23,7 +23,6 @@ from torch.testing._internal.common_utils import ( run_tests, skipIfTorchDynamo, subtest, - TEST_WITH_TORCHDYNAMO, TestCase, ) @@ -805,7 +804,6 @@ if "optree" in sys.modules: py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []), ) - @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.") def test_treespec_repr(self): # Check that it looks sane pytree = (0, [0, 0, [0]]) @@ -820,20 +818,6 @@ if "optree" in sys.modules: ), ) - @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.") - def test_treespec_repr_dynamo(self): - # Check that it looks sane - pytree = (0, [0, 0, [0]]) - _, spec = py_pytree.tree_flatten(pytree) - self.assertExpectedInline( - repr(spec), - """\ -TreeSpec(tuple, None, [*, - TreeSpec(list, None, [*, - *, - TreeSpec(list, None, [*])])])""", - ) - @parametrize( "spec", [ @@ -1365,21 +1349,12 @@ class TestCxxPytree(TestCase): def test_treespec_equality(self): self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec()) - @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.") def test_treespec_repr(self): # Check that it looks sane pytree = (0, [0, 0, [0]]) _, spec = cxx_pytree.tree_flatten(pytree) - self.assertEqual(repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)") - - @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.") - def test_treespec_repr_dynamo(self): - # Check that it looks sane - pytree = (0, [0, 0, [0]]) - _, spec = cxx_pytree.tree_flatten(pytree) - self.assertExpectedInline( - repr(spec), - "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')", + self.assertEqual( + repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')" ) @parametrize( diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 49f3ae9b75a6..8027a0f84ca6 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -61,6 +61,10 @@ __all__ = [ ] +__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch") +__TORCH_DICT_SESSION.__enter__() # enable globally and permanently + + T = TypeVar("T") S = TypeVar("S") U = TypeVar("U") @@ -285,20 +289,15 @@ def tree_flatten( >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5} >>> tree_flatten(tree) - ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)) + ([2, 3, 4, 1, None, 5], PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch')) >>> tree_flatten(1) - ([1], PyTreeSpec(*, NoneIsLeaf)) + ([1], PyTreeSpec(*, NoneIsLeaf, namespace='torch')) >>> tree_flatten(None) - ([None], PyTreeSpec(*, NoneIsLeaf)) - - For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is - dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict` - if you want to keep the keys in the insertion order. - + ([None], PyTreeSpec(*, NoneIsLeaf, namespace='torch')) >>> from collections import OrderedDict >>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)]) >>> tree_flatten(tree) - ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)) + ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf, namespace='torch')) Args: tree (pytree): A pytree to flatten. @@ -357,7 +356,7 @@ def tree_iter( >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5} >>> list(tree_iter(tree)) - [1, 2, 3, 4, None, 5] + [2, 3, 4, 1, None, 5] >>> list(tree_iter(1)) [1] >>> list(tree_iter(None)) @@ -392,7 +391,7 @@ def tree_leaves( >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5} >>> tree_leaves(tree) - [1, 2, 3, 4, None, 5] + [2, 3, 4, 1, None, 5] >>> tree_leaves(1) [1] >>> tree_leaves(None) @@ -427,11 +426,11 @@ def tree_structure( >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5} >>> tree_structure(tree) - PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf) + PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch') >>> tree_structure(1) - PyTreeSpec(*, NoneIsLeaf) + PyTreeSpec(*, NoneIsLeaf, namespace='torch') >>> tree_structure(None) - PyTreeSpec(*, NoneIsLeaf) + PyTreeSpec(*, NoneIsLeaf, namespace='torch') Args: tree (pytree): A pytree to flatten.