mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-03 23:45:05 +08:00 
			
		
		
		
	[pytree] align function signature between C++ and Python pytree (#112482)
Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations. - #112485 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482 Approved by: https://github.com/zou3519
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							7715b47f44
						
					
				
				
					commit
					4893a2814f
				
			@ -21,21 +21,21 @@ SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS: Dict[str, Type[Any]] = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def register_dataclass_as_pytree_node(
 | 
			
		||||
    typ: Any,
 | 
			
		||||
    cls: Any,
 | 
			
		||||
    flatten_fn: Optional[FlattenFunc] = None,
 | 
			
		||||
    unflatten_fn: Optional[UnflattenFunc] = None,
 | 
			
		||||
    *,
 | 
			
		||||
    serialized_type_name: Optional[str] = None,
 | 
			
		||||
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
 | 
			
		||||
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
 | 
			
		||||
    serialized_type_name: Optional[str] = None,
 | 
			
		||||
    return_none_fields: bool = False,
 | 
			
		||||
) -> None:
 | 
			
		||||
    assert dataclasses.is_dataclass(
 | 
			
		||||
        typ
 | 
			
		||||
    ), f"Only dataclasses can be registered with this function: {typ}"
 | 
			
		||||
        cls
 | 
			
		||||
    ), f"Only dataclasses can be registered with this function: {cls}"
 | 
			
		||||
 | 
			
		||||
    serialized_type = f"{typ.__module__}.{typ.__name__}"
 | 
			
		||||
    SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = typ
 | 
			
		||||
    serialized_type = f"{cls.__module__}.{cls.__name__}"
 | 
			
		||||
    SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
 | 
			
		||||
 | 
			
		||||
    def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
 | 
			
		||||
        flattened = []
 | 
			
		||||
@ -48,7 +48,7 @@ def register_dataclass_as_pytree_node(
 | 
			
		||||
                flat_names.append(name)
 | 
			
		||||
            else:
 | 
			
		||||
                none_names.append(name)
 | 
			
		||||
        return flattened, (typ, flat_names, none_names)
 | 
			
		||||
        return flattened, (cls, flat_names, none_names)
 | 
			
		||||
 | 
			
		||||
    def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
 | 
			
		||||
        typ, flat_names, none_names = context
 | 
			
		||||
@ -69,7 +69,7 @@ def register_dataclass_as_pytree_node(
 | 
			
		||||
 | 
			
		||||
    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"Both to_dumpable_context and from_dumpable_context for {typ} must "
 | 
			
		||||
            f"Both to_dumpable_context and from_dumpable_context for {cls} must "
 | 
			
		||||
            "be None or registered."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -85,7 +85,7 @@ def register_dataclass_as_pytree_node(
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    _register_pytree_node(
 | 
			
		||||
        typ,
 | 
			
		||||
        cls,
 | 
			
		||||
        flatten_fn,
 | 
			
		||||
        unflatten_fn,
 | 
			
		||||
        serialized_type_name=serialized_type_name,
 | 
			
		||||
 | 
			
		||||
@ -570,12 +570,12 @@ def load(
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def register_dataclass(typ: Any) -> None:
 | 
			
		||||
def register_dataclass(cls: Any) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Registers a dataclass as a valid input/output type for :func:`torch.export.export`.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        typ: the dataclass type to register
 | 
			
		||||
        cls: the dataclass type to register
 | 
			
		||||
 | 
			
		||||
    Example::
 | 
			
		||||
 | 
			
		||||
@ -601,4 +601,4 @@ def register_dataclass(typ: Any) -> None:
 | 
			
		||||
 | 
			
		||||
    from torch._export.utils import register_dataclass_as_pytree_node
 | 
			
		||||
 | 
			
		||||
    return register_dataclass_as_pytree_node(typ)
 | 
			
		||||
    return register_dataclass_as_pytree_node(cls)
 | 
			
		||||
 | 
			
		||||
@ -12,12 +12,12 @@ SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {}
 | 
			
		||||
SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
 | 
			
		||||
 | 
			
		||||
def register_pytree_flatten_spec(
 | 
			
		||||
    typ: Any,
 | 
			
		||||
    cls: Any,
 | 
			
		||||
    flatten_fn_spec: FlattenFuncSpec,
 | 
			
		||||
    flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None
 | 
			
		||||
) -> None:
 | 
			
		||||
    SUPPORTED_NODES[typ] = flatten_fn_spec
 | 
			
		||||
    SUPPORTED_NODES_EXACT_MATCH[typ] = flatten_fn_exact_match_spec
 | 
			
		||||
    SUPPORTED_NODES[cls] = flatten_fn_spec
 | 
			
		||||
    SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec
 | 
			
		||||
 | 
			
		||||
def tree_flatten_spec(pytree: PyTree, spec: TreeSpec, exact_structural_match=False) -> List[Any]:
 | 
			
		||||
    if isinstance(spec, LeafSpec):
 | 
			
		||||
 | 
			
		||||
@ -58,6 +58,7 @@ __all__ = [
 | 
			
		||||
 | 
			
		||||
T = TypeVar("T")
 | 
			
		||||
S = TypeVar("S")
 | 
			
		||||
U = TypeVar("U")
 | 
			
		||||
R = TypeVar("R")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -79,11 +80,11 @@ def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
 | 
			
		||||
 | 
			
		||||
def register_pytree_node(
 | 
			
		||||
    cls: Type[Any],
 | 
			
		||||
    flatten_func: FlattenFunc,
 | 
			
		||||
    unflatten_func: UnflattenFunc,
 | 
			
		||||
    namespace: str = "torch",
 | 
			
		||||
    flatten_fn: FlattenFunc,
 | 
			
		||||
    unflatten_fn: UnflattenFunc,
 | 
			
		||||
    *,
 | 
			
		||||
    serialized_type_name: Optional[str] = None,
 | 
			
		||||
    namespace: str = "torch",
 | 
			
		||||
) -> None:
 | 
			
		||||
    """Extend the set of types that are considered internal nodes in pytrees.
 | 
			
		||||
 | 
			
		||||
@ -99,20 +100,18 @@ def register_pytree_node(
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cls (type): A Python type to treat as an internal pytree node.
 | 
			
		||||
        flatten_fn (callable): A function to be used during flattening, taking an instance of ``cls``
 | 
			
		||||
            and returning a triple or optionally a pair, with (1) an iterable for the children to be
 | 
			
		||||
            flattened recursively, and (2) some hashable auxiliary data to be stored in the treespec
 | 
			
		||||
            and to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree
 | 
			
		||||
            path entries to the corresponding children. If the entries are not provided or given by
 | 
			
		||||
            :data:`None`, then `range(len(children))` will be used.
 | 
			
		||||
        unflatten_fn (callable): A function taking two arguments: the auxiliary data that was returned
 | 
			
		||||
            by ``flatten_func`` and stored in the treespec, and the unflattened children. The function
 | 
			
		||||
            should return an instance of ``cls``.
 | 
			
		||||
        namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
 | 
			
		||||
            type registry. This is used to isolate the registry from other modules that might register
 | 
			
		||||
            a different custom behavior for the same type. (default: :const:`"torch"`)
 | 
			
		||||
        flatten_fn (callable): A function to be used during flattening, taking an instance of
 | 
			
		||||
            ``cls`` and returning a pair, with (1) an iterable for the children to be flattened
 | 
			
		||||
            recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
 | 
			
		||||
            passed to the ``unflatten_fn``.
 | 
			
		||||
        unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
 | 
			
		||||
            returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
 | 
			
		||||
            The function should return an instance of ``cls``.
 | 
			
		||||
        serialized_type_name (str, optional): A keyword argument used to specify the fully
 | 
			
		||||
            qualified name used when serializing the tree spec.
 | 
			
		||||
        namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
 | 
			
		||||
            type registry. This is used to isolate the registry from other modules that might
 | 
			
		||||
            register a different custom behavior for the same type. (default: :const:`"torch"`)
 | 
			
		||||
 | 
			
		||||
    Example::
 | 
			
		||||
 | 
			
		||||
@ -198,15 +197,15 @@ def register_pytree_node(
 | 
			
		||||
 | 
			
		||||
    _register_pytree_node(
 | 
			
		||||
        cls,
 | 
			
		||||
        flatten_func,
 | 
			
		||||
        unflatten_func,
 | 
			
		||||
        flatten_fn,
 | 
			
		||||
        unflatten_fn,
 | 
			
		||||
        serialized_type_name=serialized_type_name,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    optree.register_pytree_node(
 | 
			
		||||
        cls,
 | 
			
		||||
        flatten_func,
 | 
			
		||||
        _reverse_args(unflatten_func),
 | 
			
		||||
        flatten_fn,
 | 
			
		||||
        _reverse_args(unflatten_fn),
 | 
			
		||||
        namespace=namespace,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -219,7 +218,7 @@ def tree_flatten(
 | 
			
		||||
    *,
 | 
			
		||||
    none_is_leaf: bool = True,
 | 
			
		||||
    namespace: str = "torch",
 | 
			
		||||
) -> Tuple[List[Any], PyTreeSpec]:
 | 
			
		||||
) -> Tuple[List[Any], TreeSpec]:
 | 
			
		||||
    """Flatten a pytree.
 | 
			
		||||
 | 
			
		||||
    See also :func:`tree_unflatten`.
 | 
			
		||||
@ -269,7 +268,7 @@ def tree_flatten(
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
 | 
			
		||||
def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
 | 
			
		||||
    """Reconstruct a pytree from the treespec and the leaves.
 | 
			
		||||
 | 
			
		||||
    The inverse of :func:`tree_flatten`.
 | 
			
		||||
@ -282,16 +281,16 @@ def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
 | 
			
		||||
    Args:
 | 
			
		||||
        leaves (iterable): The list of leaves to use for reconstruction. The list must match the
 | 
			
		||||
            number of leaves of the treespec.
 | 
			
		||||
        treespec (PyTreeSpec): The treespec to reconstruct.
 | 
			
		||||
        treespec (TreeSpec): The treespec to reconstruct.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        The reconstructed pytree, containing the ``leaves`` placed in the structure described by
 | 
			
		||||
        ``treespec``.
 | 
			
		||||
    """
 | 
			
		||||
    if not isinstance(treespec, PyTreeSpec):
 | 
			
		||||
    if not isinstance(treespec, TreeSpec):
 | 
			
		||||
        raise TypeError(
 | 
			
		||||
            f"tree_unflatten(values, spec): Expected `spec` to be instance of "
 | 
			
		||||
            f"PyTreeSpec but got item of type {type(treespec)}."
 | 
			
		||||
            f"TreeSpec but got item of type {type(treespec)}."
 | 
			
		||||
        )
 | 
			
		||||
    return optree.tree_unflatten(treespec, leaves)  # type: ignore[arg-type]
 | 
			
		||||
 | 
			
		||||
@ -337,7 +336,7 @@ def tree_structure(
 | 
			
		||||
    *,
 | 
			
		||||
    none_is_leaf: bool = True,
 | 
			
		||||
    namespace: str = "torch",
 | 
			
		||||
) -> PyTreeSpec:
 | 
			
		||||
) -> TreeSpec:
 | 
			
		||||
    """Get the treespec for a pytree.
 | 
			
		||||
 | 
			
		||||
    See also :func:`tree_flatten`.
 | 
			
		||||
@ -464,9 +463,11 @@ def tree_map_(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Type2 = Tuple[Type[T], Type[S]]
 | 
			
		||||
Type3 = Tuple[Type[T], Type[S], Type[U]]
 | 
			
		||||
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]
 | 
			
		||||
 | 
			
		||||
Fn2 = Callable[[Union[T, S]], R]
 | 
			
		||||
Fn3 = Callable[[Union[T, S, U]], R]
 | 
			
		||||
Fn = Callable[[T], R]
 | 
			
		||||
FnAny = Callable[[Any], R]
 | 
			
		||||
 | 
			
		||||
@ -480,6 +481,11 @@ def map_only(__type_or_types: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def map_only(__type_or_types: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def map_only(__type_or_types: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
 | 
			
		||||
    ...
 | 
			
		||||
@ -547,6 +553,18 @@ def tree_map_only(
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_map_only(
 | 
			
		||||
    __type_or_types: Type3[T, S, U],
 | 
			
		||||
    func: Fn3[T, S, U, Any],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
    *rests: PyTree,
 | 
			
		||||
    none_is_leaf: bool = True,
 | 
			
		||||
    namespace: str = "torch",
 | 
			
		||||
) -> PyTree:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_map_only(
 | 
			
		||||
    __type_or_types: TypeAny,
 | 
			
		||||
    func: FnAny[Any],
 | 
			
		||||
@ -588,6 +606,18 @@ def tree_map_only_(
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_map_only_(
 | 
			
		||||
    __type_or_types: Type3[T, S, U],
 | 
			
		||||
    func: Fn3[T, S, U, Any],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
    *rests: PyTree,
 | 
			
		||||
    none_is_leaf: bool = True,
 | 
			
		||||
    namespace: str = "torch",
 | 
			
		||||
) -> PyTree:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_map_only_(
 | 
			
		||||
    __type_or_types: TypeAny,
 | 
			
		||||
    func: FnAny[Any],
 | 
			
		||||
@ -651,6 +681,18 @@ def tree_all_only(
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_all_only(
 | 
			
		||||
    __type_or_types: Type3[T, S, U],
 | 
			
		||||
    pred: Fn3[T, S, U, bool],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
    *,
 | 
			
		||||
    none_is_leaf: bool = True,
 | 
			
		||||
    namespace: str = "torch",
 | 
			
		||||
) -> bool:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_all_only(
 | 
			
		||||
    __type_or_types: TypeAny,
 | 
			
		||||
    pred: FnAny[bool],
 | 
			
		||||
@ -687,6 +729,18 @@ def tree_any_only(
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_any_only(
 | 
			
		||||
    __type_or_types: Type3[T, S, U],
 | 
			
		||||
    pred: Fn3[T, S, U, bool],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
    *,
 | 
			
		||||
    none_is_leaf: bool = True,
 | 
			
		||||
    namespace: str = "torch",
 | 
			
		||||
) -> bool:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_any_only(
 | 
			
		||||
    __type_or_types: TypeAny,
 | 
			
		||||
    pred: FnAny[bool],
 | 
			
		||||
@ -764,12 +818,12 @@ def broadcast_prefix(
 | 
			
		||||
# _broadcast_to_and_flatten to check this.
 | 
			
		||||
def _broadcast_to_and_flatten(
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
    treespec: PyTreeSpec,
 | 
			
		||||
    treespec: TreeSpec,
 | 
			
		||||
    *,
 | 
			
		||||
    none_is_leaf: bool = True,
 | 
			
		||||
    namespace: str = "torch",
 | 
			
		||||
) -> Optional[List[Any]]:
 | 
			
		||||
    assert isinstance(treespec, PyTreeSpec)
 | 
			
		||||
    assert isinstance(treespec, TreeSpec)
 | 
			
		||||
    full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
 | 
			
		||||
    try:
 | 
			
		||||
        return broadcast_prefix(
 | 
			
		||||
@ -782,12 +836,12 @@ def _broadcast_to_and_flatten(
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def treespec_dumps(treespec: PyTreeSpec) -> str:
 | 
			
		||||
def treespec_dumps(treespec: TreeSpec) -> str:
 | 
			
		||||
    """Serialize a treespec to a JSON string."""
 | 
			
		||||
    if not isinstance(treespec, PyTreeSpec):
 | 
			
		||||
    if not isinstance(treespec, TreeSpec):
 | 
			
		||||
        raise TypeError(
 | 
			
		||||
            f"treespec_dumps(spec): Expected `spec` to be instance of "
 | 
			
		||||
            f"PyTreeSpec but got item of type {type(treespec)}."
 | 
			
		||||
            f"TreeSpec but got item of type {type(treespec)}."
 | 
			
		||||
        )
 | 
			
		||||
    from ._pytree import (
 | 
			
		||||
        tree_structure as _tree_structure,
 | 
			
		||||
@ -798,7 +852,7 @@ def treespec_dumps(treespec: PyTreeSpec) -> str:
 | 
			
		||||
    return _treespec_dumps(orig_treespec)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def treespec_loads(serialized: str) -> PyTreeSpec:
 | 
			
		||||
def treespec_loads(serialized: str) -> TreeSpec:
 | 
			
		||||
    """Deserialize a treespec from a JSON string."""
 | 
			
		||||
    from ._pytree import (
 | 
			
		||||
        tree_unflatten as _tree_unflatten,
 | 
			
		||||
@ -816,7 +870,7 @@ class _DummyLeaf:
 | 
			
		||||
        return "*"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def treespec_pprint(treespec: PyTreeSpec) -> str:
 | 
			
		||||
def treespec_pprint(treespec: TreeSpec) -> str:
 | 
			
		||||
    dummy_tree = tree_unflatten(
 | 
			
		||||
        [_DummyLeaf() for _ in range(treespec.num_leaves)],
 | 
			
		||||
        treespec,
 | 
			
		||||
@ -824,14 +878,11 @@ def treespec_pprint(treespec: PyTreeSpec) -> str:
 | 
			
		||||
    return repr(dummy_tree)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PyTreeLeafSpecMeta(type(PyTreeSpec)):  # type: ignore[misc]
 | 
			
		||||
class LeafSpecMeta(type(TreeSpec)):  # type: ignore[misc]
 | 
			
		||||
    def __instancecheck__(self, instance: object) -> bool:
 | 
			
		||||
        return isinstance(instance, PyTreeSpec) and instance.is_leaf()
 | 
			
		||||
        return isinstance(instance, TreeSpec) and instance.is_leaf()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PyTreeLeafSpec(PyTreeSpec, metaclass=PyTreeLeafSpecMeta):
 | 
			
		||||
    def __new__(cls, none_is_leaf: bool = True) -> "PyTreeLeafSpec":
 | 
			
		||||
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
 | 
			
		||||
    def __new__(cls, none_is_leaf: bool = True) -> "LeafSpec":
 | 
			
		||||
        return optree.treespec_leaf(none_is_leaf=none_is_leaf)  # type: ignore[return-value]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
LeafSpec = PyTreeLeafSpec
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,7 @@ Python values. Furthermore, a pytree should not contain reference cycles.
 | 
			
		||||
 | 
			
		||||
pytrees are useful for working with nested collections of Tensors. For example,
 | 
			
		||||
one can use `tree_map` to map a function over all Tensors inside some nested
 | 
			
		||||
collection of Tensors and `tree_unflatten` to get a flat list of all Tensors
 | 
			
		||||
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
 | 
			
		||||
inside some nested collection. pytrees are helpful for implementing nested
 | 
			
		||||
collection support for PyTorch APIs.
 | 
			
		||||
 | 
			
		||||
@ -121,11 +121,11 @@ SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _register_pytree_node(
 | 
			
		||||
    typ: Any,
 | 
			
		||||
    cls: Any,
 | 
			
		||||
    flatten_fn: FlattenFunc,
 | 
			
		||||
    unflatten_fn: UnflattenFunc,
 | 
			
		||||
    to_str_fn: Optional[ToStrFunc] = None,
 | 
			
		||||
    maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,
 | 
			
		||||
    to_str_fn: Optional[ToStrFunc] = None,  # deprecated
 | 
			
		||||
    maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,  # deprecated
 | 
			
		||||
    *,
 | 
			
		||||
    serialized_type_name: Optional[str] = None,
 | 
			
		||||
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
 | 
			
		||||
@ -133,12 +133,12 @@ def _register_pytree_node(
 | 
			
		||||
) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Args:
 | 
			
		||||
        typ: the type to register
 | 
			
		||||
        cls: the type to register
 | 
			
		||||
        flatten_fn: A callable that takes a pytree and returns a flattened
 | 
			
		||||
            representation of the pytree and additional context to represent the
 | 
			
		||||
            flattened pytree.
 | 
			
		||||
        unflatten_fn: A callable that takes a flattened version of the pytree,
 | 
			
		||||
            additional context, and returns an unflattedn pytree.
 | 
			
		||||
            additional context, and returns an unflattened pytree.
 | 
			
		||||
        serialized_type_name: A keyword argument used to specify the fully qualified
 | 
			
		||||
            name used when serializing the tree spec.
 | 
			
		||||
        to_dumpable_context: An optional keyword argument to custom specify how
 | 
			
		||||
@ -157,26 +157,29 @@ def _register_pytree_node(
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    node_def = NodeDef(
 | 
			
		||||
        typ,
 | 
			
		||||
        cls,
 | 
			
		||||
        flatten_fn,
 | 
			
		||||
        unflatten_fn,
 | 
			
		||||
    )
 | 
			
		||||
    SUPPORTED_NODES[typ] = node_def
 | 
			
		||||
    SUPPORTED_NODES[cls] = node_def
 | 
			
		||||
 | 
			
		||||
    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"Both to_dumpable_context and from_dumpable_context for {typ} must "
 | 
			
		||||
            f"Both to_dumpable_context and from_dumpable_context for {cls} must "
 | 
			
		||||
            "be None or registered."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    if serialized_type_name is None:
 | 
			
		||||
        serialized_type_name = f"{typ.__module__}.{typ.__name__}"
 | 
			
		||||
        serialized_type_name = f"{cls.__module__}.{cls.__name__}"
 | 
			
		||||
 | 
			
		||||
    serialize_node_def = _SerializeNodeDef(
 | 
			
		||||
        typ, serialized_type_name, to_dumpable_context, from_dumpable_context
 | 
			
		||||
        cls,
 | 
			
		||||
        serialized_type_name,
 | 
			
		||||
        to_dumpable_context,
 | 
			
		||||
        from_dumpable_context,
 | 
			
		||||
    )
 | 
			
		||||
    SUPPORTED_SERIALIZED_TYPES[typ] = serialize_node_def
 | 
			
		||||
    SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = typ
 | 
			
		||||
    SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
 | 
			
		||||
    SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_pytree_node = _register_pytree_node
 | 
			
		||||
@ -275,8 +278,8 @@ _register_pytree_node(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
 | 
			
		||||
def _is_namedtuple_instance(pytree: Any) -> bool:
 | 
			
		||||
    typ = type(pytree)
 | 
			
		||||
def _is_namedtuple_instance(tree: Any) -> bool:
 | 
			
		||||
    typ = type(tree)
 | 
			
		||||
    bases = typ.__bases__
 | 
			
		||||
    if len(bases) != 1 or bases[0] != tuple:
 | 
			
		||||
        return False
 | 
			
		||||
@ -286,15 +289,15 @@ def _is_namedtuple_instance(pytree: Any) -> bool:
 | 
			
		||||
    return all(type(entry) == str for entry in fields)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_node_type(pytree: Any) -> Any:
 | 
			
		||||
    if _is_namedtuple_instance(pytree):
 | 
			
		||||
def _get_node_type(tree: Any) -> Any:
 | 
			
		||||
    if _is_namedtuple_instance(tree):
 | 
			
		||||
        return namedtuple
 | 
			
		||||
    return type(pytree)
 | 
			
		||||
    return type(tree)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# A leaf is defined as anything that is not a Node.
 | 
			
		||||
def _is_leaf(pytree: PyTree) -> bool:
 | 
			
		||||
    return _get_node_type(pytree) not in SUPPORTED_NODES
 | 
			
		||||
def _is_leaf(tree: PyTree) -> bool:
 | 
			
		||||
    return _get_node_type(tree) not in SUPPORTED_NODES
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# A TreeSpec represents the structure of a pytree. It holds:
 | 
			
		||||
@ -345,109 +348,107 @@ class LeafSpec(TreeSpec):
 | 
			
		||||
_LEAF_SPEC = LeafSpec()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _tree_flatten_helper(pytree: PyTree, out_leaves: List[Any]) -> TreeSpec:
 | 
			
		||||
    if _is_leaf(pytree):
 | 
			
		||||
        out_leaves.append(pytree)
 | 
			
		||||
def _tree_flatten_helper(tree: PyTree, leaves: List[Any]) -> TreeSpec:
 | 
			
		||||
    if _is_leaf(tree):
 | 
			
		||||
        leaves.append(tree)
 | 
			
		||||
        return _LEAF_SPEC
 | 
			
		||||
 | 
			
		||||
    node_type = _get_node_type(pytree)
 | 
			
		||||
    node_type = _get_node_type(tree)
 | 
			
		||||
    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
 | 
			
		||||
    child_pytrees, context = flatten_fn(pytree)
 | 
			
		||||
    child_pytrees, context = flatten_fn(tree)
 | 
			
		||||
 | 
			
		||||
    # Recursively flatten the children
 | 
			
		||||
    children_specs = [
 | 
			
		||||
        _tree_flatten_helper(child, out_leaves) for child in child_pytrees
 | 
			
		||||
    ]
 | 
			
		||||
    children_specs = [_tree_flatten_helper(child, leaves) for child in child_pytrees]
 | 
			
		||||
 | 
			
		||||
    return TreeSpec(node_type, context, children_specs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
 | 
			
		||||
def tree_flatten(tree: PyTree) -> Tuple[List[Any], TreeSpec]:
 | 
			
		||||
    """Flattens a pytree into a list of values and a TreeSpec that can be used
 | 
			
		||||
    to reconstruct the pytree.
 | 
			
		||||
    """
 | 
			
		||||
    leaves: List[Any] = []
 | 
			
		||||
    spec = _tree_flatten_helper(pytree, leaves)
 | 
			
		||||
    spec = _tree_flatten_helper(tree, leaves)
 | 
			
		||||
    return leaves, spec
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _tree_leaves_helper(pytree: PyTree, out_leaves: List[Any]) -> None:
 | 
			
		||||
    if _is_leaf(pytree):
 | 
			
		||||
        out_leaves.append(pytree)
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    node_type = _get_node_type(pytree)
 | 
			
		||||
    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
 | 
			
		||||
    child_pytrees, _ = flatten_fn(pytree)
 | 
			
		||||
 | 
			
		||||
    # Recursively flatten the children
 | 
			
		||||
    for child in child_pytrees:
 | 
			
		||||
        _tree_leaves_helper(child, out_leaves)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_leaves(pytree: PyTree) -> List[Any]:
 | 
			
		||||
    """Get a list of leaves of a pytree."""
 | 
			
		||||
    leaves: List[Any] = []
 | 
			
		||||
    _tree_leaves_helper(pytree, leaves)
 | 
			
		||||
    return leaves
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_structure(pytree: PyTree) -> TreeSpec:
 | 
			
		||||
    """Get the TreeSpec for a pytree."""
 | 
			
		||||
    return tree_flatten(pytree)[1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_unflatten(values: Iterable[Any], spec: TreeSpec) -> PyTree:
 | 
			
		||||
def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
 | 
			
		||||
    """Given a list of values and a TreeSpec, builds a pytree.
 | 
			
		||||
    This is the inverse operation of `tree_flatten`.
 | 
			
		||||
    """
 | 
			
		||||
    if not isinstance(spec, TreeSpec):
 | 
			
		||||
    if not isinstance(treespec, TreeSpec):
 | 
			
		||||
        raise TypeError(
 | 
			
		||||
            f"tree_unflatten(values, spec): Expected `spec` to be instance of "
 | 
			
		||||
            f"TreeSpec but got item of type {type(spec)}.",
 | 
			
		||||
            f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
 | 
			
		||||
            f"instance of TreeSpec but got item of type {type(treespec)}.",
 | 
			
		||||
        )
 | 
			
		||||
    if not isinstance(values, (list, tuple)):
 | 
			
		||||
        values = list(values)
 | 
			
		||||
    if len(values) != spec.num_leaves:
 | 
			
		||||
    if not isinstance(leaves, (list, tuple)):
 | 
			
		||||
        leaves = list(leaves)
 | 
			
		||||
    if len(leaves) != treespec.num_leaves:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"tree_unflatten(values, spec): `values` has length {len(values)} "
 | 
			
		||||
            f"but the spec refers to a pytree that holds {spec.num_leaves} "
 | 
			
		||||
            f"items ({spec}).",
 | 
			
		||||
            f"tree_unflatten(leaves, treespec): `leaves` has length {len(leaves)} "
 | 
			
		||||
            f"but the spec refers to a pytree that holds {treespec.num_leaves} "
 | 
			
		||||
            f"items ({treespec}).",
 | 
			
		||||
        )
 | 
			
		||||
    if isinstance(spec, LeafSpec):
 | 
			
		||||
        return values[0]
 | 
			
		||||
    if isinstance(treespec, LeafSpec):
 | 
			
		||||
        return leaves[0]
 | 
			
		||||
 | 
			
		||||
    unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn
 | 
			
		||||
    unflatten_fn = SUPPORTED_NODES[treespec.type].unflatten_fn
 | 
			
		||||
 | 
			
		||||
    # Recursively unflatten the children
 | 
			
		||||
    start = 0
 | 
			
		||||
    end = 0
 | 
			
		||||
    child_pytrees = []
 | 
			
		||||
    for child_spec in spec.children_specs:
 | 
			
		||||
    for child_spec in treespec.children_specs:
 | 
			
		||||
        end += child_spec.num_leaves
 | 
			
		||||
        child_pytrees.append(tree_unflatten(values[start:end], child_spec))
 | 
			
		||||
        child_pytrees.append(tree_unflatten(leaves[start:end], child_spec))
 | 
			
		||||
        start = end
 | 
			
		||||
 | 
			
		||||
    return unflatten_fn(child_pytrees, spec.context)
 | 
			
		||||
    return unflatten_fn(child_pytrees, treespec.context)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_map(fn: Any, pytree: PyTree) -> PyTree:
 | 
			
		||||
    flat_args, spec = tree_flatten(pytree)
 | 
			
		||||
    return tree_unflatten([fn(i) for i in flat_args], spec)
 | 
			
		||||
def _tree_leaves_helper(tree: PyTree, leaves: List[Any]) -> None:
 | 
			
		||||
    if _is_leaf(tree):
 | 
			
		||||
        leaves.append(tree)
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    node_type = _get_node_type(tree)
 | 
			
		||||
    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
 | 
			
		||||
    child_pytrees, _ = flatten_fn(tree)
 | 
			
		||||
 | 
			
		||||
    # Recursively flatten the children
 | 
			
		||||
    for child in child_pytrees:
 | 
			
		||||
        _tree_leaves_helper(child, leaves)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_map_(fn: Any, pytree: PyTree) -> PyTree:
 | 
			
		||||
    flat_args = tree_leaves(pytree)
 | 
			
		||||
    deque(map(fn, flat_args), maxlen=0)  # consume and exhaust the iterable
 | 
			
		||||
    return pytree
 | 
			
		||||
def tree_leaves(tree: PyTree) -> List[Any]:
 | 
			
		||||
    """Get a list of leaves of a pytree."""
 | 
			
		||||
    leaves: List[Any] = []
 | 
			
		||||
    _tree_leaves_helper(tree, leaves)
 | 
			
		||||
    return leaves
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_structure(tree: PyTree) -> TreeSpec:
 | 
			
		||||
    """Get the TreeSpec for a pytree."""
 | 
			
		||||
    return tree_flatten(tree)[1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_map(func: Any, tree: PyTree) -> PyTree:
 | 
			
		||||
    flat_args, spec = tree_flatten(tree)
 | 
			
		||||
    return tree_unflatten([func(i) for i in flat_args], spec)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_map_(func: Any, tree: PyTree) -> PyTree:
 | 
			
		||||
    flat_args = tree_leaves(tree)
 | 
			
		||||
    deque(map(func, flat_args), maxlen=0)  # consume and exhaust the iterable
 | 
			
		||||
    return tree
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Type2 = Tuple[Type[T], Type[S]]
 | 
			
		||||
Type3 = Tuple[Type[T], Type[S], Type[U]]
 | 
			
		||||
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]
 | 
			
		||||
 | 
			
		||||
Fn3 = Callable[[Union[T, S, U]], R]
 | 
			
		||||
Fn2 = Callable[[Union[T, S]], R]
 | 
			
		||||
Fn3 = Callable[[Union[T, S, U]], R]
 | 
			
		||||
Fn = Callable[[T], R]
 | 
			
		||||
FnAny = Callable[[Any], R]
 | 
			
		||||
 | 
			
		||||
@ -457,22 +458,27 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
 | 
			
		||||
# These specializations help with type inference on the lambda passed to this
 | 
			
		||||
# function
 | 
			
		||||
@overload
 | 
			
		||||
def map_only(ty: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
 | 
			
		||||
def map_only(__type_or_types: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def map_only(ty: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
 | 
			
		||||
def map_only(__type_or_types: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def map_only(__type_or_types: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# This specialization is needed for the implementations below that call
 | 
			
		||||
@overload
 | 
			
		||||
def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
 | 
			
		||||
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
 | 
			
		||||
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
 | 
			
		||||
    """
 | 
			
		||||
    Suppose you are writing a tree_map over tensors, leaving everything
 | 
			
		||||
    else unchanged.  Ordinarily you would have to write:
 | 
			
		||||
@ -492,99 +498,168 @@ def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
 | 
			
		||||
    You can also directly use 'tree_map_only'
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def deco(f: Callable[[T], Any]) -> Callable[[Any], Any]:
 | 
			
		||||
        def inner(x: T) -> Any:
 | 
			
		||||
            if isinstance(x, ty):
 | 
			
		||||
                return f(x)
 | 
			
		||||
            else:
 | 
			
		||||
    def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
 | 
			
		||||
        # @functools.wraps(func)  # torch dynamo doesn't support this yet
 | 
			
		||||
        def wrapped(x: T) -> Any:
 | 
			
		||||
            if isinstance(x, __type_or_types):
 | 
			
		||||
                return func(x)
 | 
			
		||||
            return x
 | 
			
		||||
 | 
			
		||||
        return inner
 | 
			
		||||
        return wrapped
 | 
			
		||||
 | 
			
		||||
    return deco
 | 
			
		||||
    return wrapper
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_map_only(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree:
 | 
			
		||||
def tree_map_only(
 | 
			
		||||
    __type_or_types: Type[T],
 | 
			
		||||
    func: Fn[T, Any],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> PyTree:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_map_only(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree:
 | 
			
		||||
def tree_map_only(
 | 
			
		||||
    __type_or_types: Type2[T, S],
 | 
			
		||||
    func: Fn2[T, S, Any],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> PyTree:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_map_only(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) -> PyTree:
 | 
			
		||||
def tree_map_only(
 | 
			
		||||
    __type_or_types: Type3[T, S, U],
 | 
			
		||||
    func: Fn3[T, S, U, Any],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> PyTree:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:
 | 
			
		||||
    return tree_map(map_only(ty)(fn), pytree)
 | 
			
		||||
def tree_map_only(
 | 
			
		||||
    __type_or_types: TypeAny,
 | 
			
		||||
    func: FnAny[Any],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> PyTree:
 | 
			
		||||
    return tree_map(map_only(__type_or_types)(func), tree)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_map_only_(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree:
 | 
			
		||||
def tree_map_only_(
 | 
			
		||||
    __type_or_types: Type[T],
 | 
			
		||||
    func: Fn[T, Any],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> PyTree:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_map_only_(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree:
 | 
			
		||||
def tree_map_only_(
 | 
			
		||||
    __type_or_types: Type2[T, S],
 | 
			
		||||
    func: Fn2[T, S, Any],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> PyTree:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_map_only_(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) -> PyTree:
 | 
			
		||||
def tree_map_only_(
 | 
			
		||||
    __type_or_types: Type3[T, S, U],
 | 
			
		||||
    func: Fn3[T, S, U, Any],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> PyTree:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_map_only_(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:
 | 
			
		||||
    return tree_map_(map_only(ty)(fn), pytree)
 | 
			
		||||
def tree_map_only_(
 | 
			
		||||
    __type_or_types: TypeAny,
 | 
			
		||||
    func: FnAny[Any],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> PyTree:
 | 
			
		||||
    return tree_map_(map_only(__type_or_types)(func), tree)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
 | 
			
		||||
    flat_args = tree_leaves(pytree)
 | 
			
		||||
def tree_all(pred: Callable[[Any], bool], tree: PyTree) -> bool:
 | 
			
		||||
    flat_args = tree_leaves(tree)
 | 
			
		||||
    return all(map(pred, flat_args))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_any(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
 | 
			
		||||
    flat_args = tree_leaves(pytree)
 | 
			
		||||
def tree_any(pred: Callable[[Any], bool], tree: PyTree) -> bool:
 | 
			
		||||
    flat_args = tree_leaves(tree)
 | 
			
		||||
    return any(map(pred, flat_args))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_all_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
 | 
			
		||||
def tree_all_only(
 | 
			
		||||
    __type_or_types: Type[T],
 | 
			
		||||
    pred: Fn[T, bool],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> bool:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_all_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
 | 
			
		||||
def tree_all_only(
 | 
			
		||||
    __type_or_types: Type2[T, S],
 | 
			
		||||
    pred: Fn2[T, S, bool],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> bool:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_all_only(ty: Type3[T, S, U], pred: Fn3[T, S, U, bool], pytree: PyTree) -> bool:
 | 
			
		||||
def tree_all_only(
 | 
			
		||||
    __type_or_types: Type3[T, S, U],
 | 
			
		||||
    pred: Fn3[T, S, U, bool],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> bool:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_all_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
 | 
			
		||||
    flat_args = tree_leaves(pytree)
 | 
			
		||||
    return all(pred(x) for x in flat_args if isinstance(x, ty))
 | 
			
		||||
def tree_all_only(
 | 
			
		||||
    __type_or_types: TypeAny,
 | 
			
		||||
    pred: FnAny[bool],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> bool:
 | 
			
		||||
    flat_args = tree_leaves(tree)
 | 
			
		||||
    return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_any_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
 | 
			
		||||
def tree_any_only(
 | 
			
		||||
    __type_or_types: Type[T],
 | 
			
		||||
    pred: Fn[T, bool],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> bool:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def tree_any_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
 | 
			
		||||
def tree_any_only(
 | 
			
		||||
    __type_or_types: Type2[T, S],
 | 
			
		||||
    pred: Fn2[T, S, bool],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> bool:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_any_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
 | 
			
		||||
    flat_args = tree_leaves(pytree)
 | 
			
		||||
    return any(pred(x) for x in flat_args if isinstance(x, ty))
 | 
			
		||||
@overload
 | 
			
		||||
def tree_any_only(
 | 
			
		||||
    __type_or_types: Type3[T, S, U],
 | 
			
		||||
    pred: Fn3[T, S, U, bool],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> bool:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_any_only(
 | 
			
		||||
    __type_or_types: TypeAny,
 | 
			
		||||
    pred: FnAny[bool],
 | 
			
		||||
    tree: PyTree,
 | 
			
		||||
) -> bool:
 | 
			
		||||
    flat_args = tree_leaves(tree)
 | 
			
		||||
    return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Broadcasts a pytree to the provided TreeSpec and returns the flattened
 | 
			
		||||
@ -595,27 +670,27 @@ def tree_any_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
 | 
			
		||||
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
 | 
			
		||||
# broadcastable to the tree structure of `inputs` and we use
 | 
			
		||||
# _broadcast_to_and_flatten to check this.
 | 
			
		||||
def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]:
 | 
			
		||||
    assert isinstance(spec, TreeSpec)
 | 
			
		||||
def _broadcast_to_and_flatten(tree: PyTree, treespec: TreeSpec) -> Optional[List[Any]]:
 | 
			
		||||
    assert isinstance(treespec, TreeSpec)
 | 
			
		||||
 | 
			
		||||
    if _is_leaf(pytree):
 | 
			
		||||
        return [pytree] * spec.num_leaves
 | 
			
		||||
    if isinstance(spec, LeafSpec):
 | 
			
		||||
    if _is_leaf(tree):
 | 
			
		||||
        return [tree] * treespec.num_leaves
 | 
			
		||||
    if isinstance(treespec, LeafSpec):
 | 
			
		||||
        return None
 | 
			
		||||
    node_type = _get_node_type(pytree)
 | 
			
		||||
    if node_type != spec.type:
 | 
			
		||||
    node_type = _get_node_type(tree)
 | 
			
		||||
    if node_type != treespec.type:
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
 | 
			
		||||
    child_pytrees, ctx = flatten_fn(pytree)
 | 
			
		||||
    child_pytrees, ctx = flatten_fn(tree)
 | 
			
		||||
 | 
			
		||||
    # Check if the Node is different from the spec
 | 
			
		||||
    if len(child_pytrees) != len(spec.children_specs) or ctx != spec.context:
 | 
			
		||||
    if len(child_pytrees) != len(treespec.children_specs) or ctx != treespec.context:
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    # Recursively flatten the children
 | 
			
		||||
    result: List[Any] = []
 | 
			
		||||
    for child, child_spec in zip(child_pytrees, spec.children_specs):
 | 
			
		||||
    for child, child_spec in zip(child_pytrees, treespec.children_specs):
 | 
			
		||||
        flat = _broadcast_to_and_flatten(child, child_spec)
 | 
			
		||||
        if flat is not None:
 | 
			
		||||
            result += flat
 | 
			
		||||
@ -648,23 +723,28 @@ class _ProtocolFn(NamedTuple):
 | 
			
		||||
_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _treespec_to_json(spec: TreeSpec) -> _TreeSpecSchema:
 | 
			
		||||
    if isinstance(spec, LeafSpec):
 | 
			
		||||
def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
 | 
			
		||||
    if isinstance(treespec, LeafSpec):
 | 
			
		||||
        return _TreeSpecSchema(None, None, [])
 | 
			
		||||
 | 
			
		||||
    serialize_node_def = SUPPORTED_SERIALIZED_TYPES[spec.type]
 | 
			
		||||
    if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
 | 
			
		||||
        raise NotImplementedError(
 | 
			
		||||
            f"Serializing {treespec.type} in pytree is not registered."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]
 | 
			
		||||
 | 
			
		||||
    serialized_type_name = serialize_node_def.serialized_type_name
 | 
			
		||||
 | 
			
		||||
    if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
 | 
			
		||||
        raise NotImplementedError(
 | 
			
		||||
            f"No registered serialization name for {spec.type} found. "
 | 
			
		||||
            f"No registered serialization name for {treespec.type} found. "
 | 
			
		||||
            "Please update your _register_pytree_node call with a `serialized_type_name` kwarg."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    if serialize_node_def.to_dumpable_context is None:
 | 
			
		||||
        try:
 | 
			
		||||
            serialized_context = json.dumps(spec.context)
 | 
			
		||||
            serialized_context = json.dumps(treespec.context)
 | 
			
		||||
        except TypeError as e:
 | 
			
		||||
            raise TypeError(
 | 
			
		||||
                "Unable to serialize context. "
 | 
			
		||||
@ -672,9 +752,9 @@ def _treespec_to_json(spec: TreeSpec) -> _TreeSpecSchema:
 | 
			
		||||
                "custom serializer using _register_pytree_node."
 | 
			
		||||
            ) from e
 | 
			
		||||
    else:
 | 
			
		||||
        serialized_context = serialize_node_def.to_dumpable_context(spec.context)
 | 
			
		||||
        serialized_context = serialize_node_def.to_dumpable_context(treespec.context)
 | 
			
		||||
 | 
			
		||||
    child_schemas = [_treespec_to_json(child) for child in spec.children_specs]
 | 
			
		||||
    child_schemas = [_treespec_to_json(child) for child in treespec.children_specs]
 | 
			
		||||
 | 
			
		||||
    return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)
 | 
			
		||||
 | 
			
		||||
@ -764,9 +844,9 @@ def treespec_pprint(treespec: TreeSpec) -> str:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO(angelayi): remove this function after OSS/internal stabilize
 | 
			
		||||
def pytree_to_str(spec: TreeSpec) -> str:
 | 
			
		||||
def pytree_to_str(treespec: TreeSpec) -> str:
 | 
			
		||||
    warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps")
 | 
			
		||||
    return treespec_dumps(spec)
 | 
			
		||||
    return treespec_dumps(treespec)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO(angelayi): remove this function after OSS/internal stabilize
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user