mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo][pytree][1/N] make CXX pytree traceable: tree_iter
/ tree_leaves
(#137397)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137397 Approved by: https://github.com/jansel ghstack dependencies: #141360
This commit is contained in:
committed by
PyTorch MergeBot
parent
cdde73033e
commit
07850bb2c1
@ -30,10 +30,10 @@ from typing import (
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import optree
|
||||
from optree import PyTreeSpec # direct import for type annotations
|
||||
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
|
||||
|
||||
import torch.utils._pytree as _pytree
|
||||
from torch.utils._pytree import KeyEntry
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch.utils._pytree import KeyEntry as KeyEntry
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -79,7 +79,6 @@ R = TypeVar("R")
|
||||
|
||||
Context = Any
|
||||
PyTree = Any
|
||||
TreeSpec = PyTreeSpec
|
||||
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
|
||||
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
|
||||
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
|
||||
@ -151,9 +150,7 @@ def register_pytree_node(
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
)
|
||||
|
||||
from . import _pytree as python
|
||||
|
||||
python._private_register_pytree_node(
|
||||
python_pytree._private_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
@ -871,24 +868,19 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
|
||||
f"treespec_dumps(spec): Expected `spec` to be instance of "
|
||||
f"TreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
from ._pytree import (
|
||||
tree_structure as _tree_structure,
|
||||
treespec_dumps as _treespec_dumps,
|
||||
)
|
||||
|
||||
orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
|
||||
return _treespec_dumps(orig_treespec, protocol=protocol)
|
||||
dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
|
||||
orig_treespec = python_pytree.tree_structure(dummy_tree)
|
||||
return python_pytree.treespec_dumps(orig_treespec, protocol=protocol)
|
||||
|
||||
|
||||
def treespec_loads(serialized: str) -> TreeSpec:
|
||||
"""Deserialize a treespec from a JSON string."""
|
||||
from ._pytree import (
|
||||
tree_unflatten as _tree_unflatten,
|
||||
treespec_loads as _treespec_loads,
|
||||
orig_treespec = python_pytree.treespec_loads(serialized)
|
||||
dummy_tree = python_pytree.tree_unflatten(
|
||||
[0] * orig_treespec.num_leaves,
|
||||
orig_treespec,
|
||||
)
|
||||
|
||||
orig_treespec = _treespec_loads(serialized)
|
||||
dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
|
||||
treespec = tree_structure(dummy_tree)
|
||||
return treespec
|
||||
|
||||
@ -1002,6 +994,10 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
|
||||
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
|
||||
|
||||
|
||||
_pytree._cxx_pytree_imported = True
|
||||
for args, kwargs in _pytree._cxx_pytree_pending_imports:
|
||||
_private_register_pytree_node(*args, **kwargs)
|
||||
with python_pytree._NODE_REGISTRY_LOCK:
|
||||
python_pytree._cxx_pytree_imported = True
|
||||
args, kwargs = (), {} # type: ignore[var-annotated]
|
||||
for args, kwargs in python_pytree._cxx_pytree_pending_imports:
|
||||
_private_register_pytree_node(*args, **kwargs)
|
||||
python_pytree._cxx_pytree_pending_imports.clear()
|
||||
del args, kwargs
|
||||
|
Reference in New Issue
Block a user