mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
5e2adc8650
commit
a0d00349ed
@ -26,6 +26,11 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
if torch._running_with_deploy():
|
||||
raise ImportError("C++ pytree utilities do not work with torch::deploy.")
|
||||
|
||||
import optree
|
||||
from optree import PyTreeSpec # direct import for type annotations
|
||||
|
||||
@ -35,6 +40,9 @@ __all__ = [
|
||||
"Context",
|
||||
"FlattenFunc",
|
||||
"UnflattenFunc",
|
||||
"DumpableContext",
|
||||
"ToDumpableContextFn",
|
||||
"FromDumpableContextFn",
|
||||
"TreeSpec",
|
||||
"LeafSpec",
|
||||
"register_pytree_node",
|
||||
@ -68,6 +76,9 @@ TreeSpec = PyTreeSpec
|
||||
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
|
||||
UnflattenFunc = Callable[[Iterable, Context], PyTree]
|
||||
OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
|
||||
DumpableContext = Any # Any json dumpable text
|
||||
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
||||
|
||||
|
||||
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
|
||||
@ -84,6 +95,8 @@ def register_pytree_node(
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
namespace: str = "torch",
|
||||
) -> None:
|
||||
"""Extend the set of types that are considered internal nodes in pytrees.
|
||||
@ -109,6 +122,13 @@ def register_pytree_node(
|
||||
The function should return an instance of ``cls``.
|
||||
serialized_type_name (str, optional): A keyword argument used to specify the fully
|
||||
qualified name used when serializing the tree spec.
|
||||
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
|
||||
to convert the context of the pytree to a custom json dumpable representation. This is
|
||||
used for json serialization, which is being used in :mod:`torch.export` right now.
|
||||
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
|
||||
how to convert the custom json dumpable representation of the context back to the
|
||||
original context. This is used for json deserialization, which is being used in
|
||||
:mod:`torch.export` right now.
|
||||
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
||||
type registry. This is used to isolate the registry from other modules that might
|
||||
register a different custom behavior for the same type. (default: :const:`"torch"`)
|
||||
@ -193,26 +213,56 @@ def register_pytree_node(
|
||||
)
|
||||
)
|
||||
"""
|
||||
from ._pytree import _register_pytree_node
|
||||
|
||||
_register_pytree_node(
|
||||
_private_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
optree.register_pytree_node(
|
||||
from . import _pytree as python
|
||||
|
||||
python._private_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
_reverse_args(unflatten_fn),
|
||||
namespace=namespace,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
)
|
||||
|
||||
|
||||
_register_pytree_node = register_pytree_node
|
||||
|
||||
|
||||
def _private_register_pytree_node(
|
||||
cls: Type[Any],
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
namespace: str = "torch",
|
||||
) -> None:
|
||||
"""This is an internal function that is used to register a pytree node type
|
||||
for the C++ pytree only. End-users should use :func:`register_pytree_node`
|
||||
instead.
|
||||
"""
|
||||
# TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
|
||||
# PyStructSequence types
|
||||
if not optree.is_structseq_class(cls):
|
||||
optree.register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
_reverse_args(unflatten_fn),
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
def tree_flatten(
|
||||
tree: PyTree,
|
||||
*,
|
||||
|
Reference in New Issue
Block a user