[dynamo][pytree][1/N] make CXX pytree traceable: tree_iter / tree_leaves (#137397)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137397
Approved by: https://github.com/jansel
ghstack dependencies: #141360
This commit is contained in:
Xuehai Pan
2024-11-27 02:54:50 +08:00
committed by PyTorch MergeBot
parent cdde73033e
commit 07850bb2c1
7 changed files with 139 additions and 57 deletions

View File

@ -30,10 +30,10 @@ from typing import (
from typing_extensions import deprecated
import optree
from optree import PyTreeSpec # direct import for type annotations
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
import torch.utils._pytree as _pytree
from torch.utils._pytree import KeyEntry
import torch.utils._pytree as python_pytree
from torch.utils._pytree import KeyEntry as KeyEntry
__all__ = [
@ -79,7 +79,6 @@ R = TypeVar("R")
Context = Any
PyTree = Any
TreeSpec = PyTreeSpec
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
@ -151,9 +150,7 @@ def register_pytree_node(
from_dumpable_context=from_dumpable_context,
)
from . import _pytree as python
python._private_register_pytree_node(
python_pytree._private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
@ -871,24 +868,19 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
f"treespec_dumps(spec): Expected `spec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}."
)
from ._pytree import (
tree_structure as _tree_structure,
treespec_dumps as _treespec_dumps,
)
orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
return _treespec_dumps(orig_treespec, protocol=protocol)
dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
orig_treespec = python_pytree.tree_structure(dummy_tree)
return python_pytree.treespec_dumps(orig_treespec, protocol=protocol)
def treespec_loads(serialized: str) -> TreeSpec:
"""Deserialize a treespec from a JSON string."""
from ._pytree import (
tree_unflatten as _tree_unflatten,
treespec_loads as _treespec_loads,
orig_treespec = python_pytree.treespec_loads(serialized)
dummy_tree = python_pytree.tree_unflatten(
[0] * orig_treespec.num_leaves,
orig_treespec,
)
orig_treespec = _treespec_loads(serialized)
dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
treespec = tree_structure(dummy_tree)
return treespec
@ -1002,6 +994,10 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
_pytree._cxx_pytree_imported = True
for args, kwargs in _pytree._cxx_pytree_pending_imports:
_private_register_pytree_node(*args, **kwargs)
with python_pytree._NODE_REGISTRY_LOCK:
python_pytree._cxx_pytree_imported = True
args, kwargs = (), {} # type: ignore[var-annotated]
for args, kwargs in python_pytree._cxx_pytree_pending_imports:
_private_register_pytree_node(*args, **kwargs)
python_pytree._cxx_pytree_pending_imports.clear()
del args, kwargs