[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2023-11-07 21:20:24 +08:00
committed by PyTorch MergeBot
parent 5e2adc8650
commit a0d00349ed
6 changed files with 130 additions and 36 deletions

View File

@ -26,6 +26,11 @@ from typing import (
Union,
)
import torch
if torch._running_with_deploy():
raise ImportError("C++ pytree utilities do not work with torch::deploy.")
import optree
from optree import PyTreeSpec # direct import for type annotations
@ -35,6 +40,9 @@ __all__ = [
"Context",
"FlattenFunc",
"UnflattenFunc",
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"TreeSpec",
"LeafSpec",
"register_pytree_node",
@ -68,6 +76,9 @@ TreeSpec = PyTreeSpec
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
UnflattenFunc = Callable[[Iterable, Context], PyTree]
OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@ -84,6 +95,8 @@ def register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
namespace: str = "torch",
) -> None:
"""Extend the set of types that are considered internal nodes in pytrees.
@ -109,6 +122,13 @@ def register_pytree_node(
The function should return an instance of ``cls``.
serialized_type_name (str, optional): A keyword argument used to specify the fully
qualified name used when serializing the tree spec.
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
to convert the context of the pytree to a custom json dumpable representation. This is
used for json serialization, which is being used in :mod:`torch.export` right now.
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
how to convert the custom json dumpable representation of the context back to the
original context. This is used for json deserialization, which is being used in
:mod:`torch.export` right now.
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
type registry. This is used to isolate the registry from other modules that might
register a different custom behavior for the same type. (default: :const:`"torch"`)
@ -193,26 +213,56 @@ def register_pytree_node(
)
)
"""
from ._pytree import _register_pytree_node
_register_pytree_node(
_private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
namespace=namespace,
)
optree.register_pytree_node(
from . import _pytree as python
python._private_register_pytree_node(
cls,
flatten_fn,
_reverse_args(unflatten_fn),
namespace=namespace,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
)
_register_pytree_node = register_pytree_node
def _private_register_pytree_node(
cls: Type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
namespace: str = "torch",
) -> None:
"""This is an internal function that is used to register a pytree node type
for the C++ pytree only. End-users should use :func:`register_pytree_node`
instead.
"""
# TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
# PyStructSequence types
if not optree.is_structseq_class(cls):
optree.register_pytree_node(
cls,
flatten_fn,
_reverse_args(unflatten_fn),
namespace=namespace,
)
def tree_flatten(
tree: PyTree,
*,