[pytree][Easy] preserve dict keys in insertion order in CXX pytree (#130140)

`optree` and JAX pytree traversal the `dict` in sorted key ordering (see [Key Ordering for Dictionaries](https://github.com/metaopt/optree#key-ordering-for-dictionaries)). While in PyTorch Python pytree, we traversal the `dict` in insertion order. See also:

- #114392

This aligns the behavior of CXX pytree with Python pytree.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130140
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2025-02-09 22:38:00 +08:00
committed by PyTorch MergeBot
parent 1f8ff94d4f
commit 9abaaad6a8
2 changed files with 15 additions and 41 deletions

View File

@ -61,6 +61,10 @@ __all__ = [
]
__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch")
__TORCH_DICT_SESSION.__enter__() # enable globally and permanently
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
@ -285,20 +289,15 @@ def tree_flatten(
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> tree_flatten(tree)
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
([2, 3, 4, 1, None, 5], PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch'))
>>> tree_flatten(1)
([1], PyTreeSpec(*, NoneIsLeaf))
([1], PyTreeSpec(*, NoneIsLeaf, namespace='torch'))
>>> tree_flatten(None)
([None], PyTreeSpec(*, NoneIsLeaf))
For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
if you want to keep the keys in the insertion order.
([None], PyTreeSpec(*, NoneIsLeaf, namespace='torch'))
>>> from collections import OrderedDict
>>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)])
>>> tree_flatten(tree)
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf, namespace='torch'))
Args:
tree (pytree): A pytree to flatten.
@ -357,7 +356,7 @@ def tree_iter(
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> list(tree_iter(tree))
[1, 2, 3, 4, None, 5]
[2, 3, 4, 1, None, 5]
>>> list(tree_iter(1))
[1]
>>> list(tree_iter(None))
@ -392,7 +391,7 @@ def tree_leaves(
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> tree_leaves(tree)
[1, 2, 3, 4, None, 5]
[2, 3, 4, 1, None, 5]
>>> tree_leaves(1)
[1]
>>> tree_leaves(None)
@ -427,11 +426,11 @@ def tree_structure(
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> tree_structure(tree)
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch')
>>> tree_structure(1)
PyTreeSpec(*, NoneIsLeaf)
PyTreeSpec(*, NoneIsLeaf, namespace='torch')
>>> tree_structure(None)
PyTreeSpec(*, NoneIsLeaf)
PyTreeSpec(*, NoneIsLeaf, namespace='torch')
Args:
tree (pytree): A pytree to flatten.