[dynamo][pytree][2/N] make CXX pytree traceable: tree_flatten / tree_unflatten / tree_structure (#137398)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137398
Approved by: https://github.com/jansel
This commit is contained in:
Xuehai Pan
2024-12-09 18:24:38 +08:00
committed by PyTorch MergeBot
parent c85323c5e8
commit 7edeb1005a
5 changed files with 242 additions and 41 deletions

View File

@ -27,7 +27,7 @@ from typing import (
TypeVar,
Union,
)
from typing_extensions import deprecated
from typing_extensions import deprecated, TypeIs
import optree
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
@ -240,6 +240,10 @@ def _private_register_pytree_node(
)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
return isinstance(obj, TreeSpec)
def tree_is_leaf(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -345,10 +349,10 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
if not isinstance(treespec, TreeSpec):
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(values, spec): Expected `spec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}."
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
@ -891,7 +895,7 @@ def _broadcast_to_and_flatten(
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[List[Any]]:
assert isinstance(treespec, TreeSpec)
assert _is_pytreespec_instance(treespec)
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
@ -901,10 +905,10 @@ def _broadcast_to_and_flatten(
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
"""Serialize a treespec to a JSON string."""
if not isinstance(treespec, TreeSpec):
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"treespec_dumps(spec): Expected `spec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}."
f"treespec_dumps(treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
@ -938,7 +942,7 @@ def treespec_pprint(treespec: TreeSpec) -> str:
class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc]
def __instancecheck__(self, instance: object) -> bool:
return isinstance(instance, TreeSpec) and instance.is_leaf()
return _is_pytreespec_instance(instance) and instance.is_leaf()
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):