mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[pytree][Easy] preserve dict
keys in insertion order in CXX pytree (#130140)
`optree` and JAX pytree traversal the `dict` in sorted key ordering (see [Key Ordering for Dictionaries](https://github.com/metaopt/optree#key-ordering-for-dictionaries)). While in PyTorch Python pytree, we traversal the `dict` in insertion order. See also: - #114392 This aligns the behavior of CXX pytree with Python pytree. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130140 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
1f8ff94d4f
commit
9abaaad6a8
@ -61,6 +61,10 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch")
|
||||
__TORCH_DICT_SESSION.__enter__() # enable globally and permanently
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
U = TypeVar("U")
|
||||
@ -285,20 +289,15 @@ def tree_flatten(
|
||||
|
||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||
>>> tree_flatten(tree)
|
||||
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
|
||||
([2, 3, 4, 1, None, 5], PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch'))
|
||||
>>> tree_flatten(1)
|
||||
([1], PyTreeSpec(*, NoneIsLeaf))
|
||||
([1], PyTreeSpec(*, NoneIsLeaf, namespace='torch'))
|
||||
>>> tree_flatten(None)
|
||||
([None], PyTreeSpec(*, NoneIsLeaf))
|
||||
|
||||
For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
|
||||
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
|
||||
if you want to keep the keys in the insertion order.
|
||||
|
||||
([None], PyTreeSpec(*, NoneIsLeaf, namespace='torch'))
|
||||
>>> from collections import OrderedDict
|
||||
>>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)])
|
||||
>>> tree_flatten(tree)
|
||||
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))
|
||||
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf, namespace='torch'))
|
||||
|
||||
Args:
|
||||
tree (pytree): A pytree to flatten.
|
||||
@ -357,7 +356,7 @@ def tree_iter(
|
||||
|
||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||
>>> list(tree_iter(tree))
|
||||
[1, 2, 3, 4, None, 5]
|
||||
[2, 3, 4, 1, None, 5]
|
||||
>>> list(tree_iter(1))
|
||||
[1]
|
||||
>>> list(tree_iter(None))
|
||||
@ -392,7 +391,7 @@ def tree_leaves(
|
||||
|
||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||
>>> tree_leaves(tree)
|
||||
[1, 2, 3, 4, None, 5]
|
||||
[2, 3, 4, 1, None, 5]
|
||||
>>> tree_leaves(1)
|
||||
[1]
|
||||
>>> tree_leaves(None)
|
||||
@ -427,11 +426,11 @@ def tree_structure(
|
||||
|
||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||
>>> tree_structure(tree)
|
||||
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
|
||||
PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch')
|
||||
>>> tree_structure(1)
|
||||
PyTreeSpec(*, NoneIsLeaf)
|
||||
PyTreeSpec(*, NoneIsLeaf, namespace='torch')
|
||||
>>> tree_structure(None)
|
||||
PyTreeSpec(*, NoneIsLeaf)
|
||||
PyTreeSpec(*, NoneIsLeaf, namespace='torch')
|
||||
|
||||
Args:
|
||||
tree (pytree): A pytree to flatten.
|
||||
|
Reference in New Issue
Block a user