mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo][pytree][2/N] make CXX pytree traceable: tree_flatten
/ tree_unflatten
/ tree_structure
(#137398)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137398 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
c85323c5e8
commit
7edeb1005a
@ -27,7 +27,7 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import deprecated
|
||||
from typing_extensions import deprecated, TypeIs
|
||||
|
||||
import optree
|
||||
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
|
||||
@ -240,6 +240,10 @@ def _private_register_pytree_node(
|
||||
)
|
||||
|
||||
|
||||
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
|
||||
return isinstance(obj, TreeSpec)
|
||||
|
||||
|
||||
def tree_is_leaf(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
@ -345,10 +349,10 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
||||
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
|
||||
``treespec``.
|
||||
"""
|
||||
if not isinstance(treespec, TreeSpec):
|
||||
if not _is_pytreespec_instance(treespec):
|
||||
raise TypeError(
|
||||
f"tree_unflatten(values, spec): Expected `spec` to be instance of "
|
||||
f"TreeSpec but got item of type {type(treespec)}."
|
||||
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
|
||||
|
||||
@ -891,7 +895,7 @@ def _broadcast_to_and_flatten(
|
||||
treespec: TreeSpec,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Optional[List[Any]]:
|
||||
assert isinstance(treespec, TreeSpec)
|
||||
assert _is_pytreespec_instance(treespec)
|
||||
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
|
||||
try:
|
||||
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
|
||||
@ -901,10 +905,10 @@ def _broadcast_to_and_flatten(
|
||||
|
||||
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
|
||||
"""Serialize a treespec to a JSON string."""
|
||||
if not isinstance(treespec, TreeSpec):
|
||||
if not _is_pytreespec_instance(treespec):
|
||||
raise TypeError(
|
||||
f"treespec_dumps(spec): Expected `spec` to be instance of "
|
||||
f"TreeSpec but got item of type {type(treespec)}."
|
||||
f"treespec_dumps(treespec): Expected `treespec` to be instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
|
||||
dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
|
||||
@ -938,7 +942,7 @@ def treespec_pprint(treespec: TreeSpec) -> str:
|
||||
|
||||
class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc]
|
||||
def __instancecheck__(self, instance: object) -> bool:
|
||||
return isinstance(instance, TreeSpec) and instance.is_leaf()
|
||||
return _is_pytreespec_instance(instance) and instance.is_leaf()
|
||||
|
||||
|
||||
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
|
||||
|
Reference in New Issue
Block a user