Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"

This reverts commit 4e4a6ad6ecd71a1aefde3992ecf7f77e37d2e264.

Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1824099658))
This commit is contained in:
PyTorch MergeBot
2023-11-23 09:59:29 +00:00
parent a76bb5d84d
commit 01366efcc9
11 changed files with 66 additions and 396 deletions

View File

@ -13,7 +13,6 @@ collection support for PyTorch APIs.
"""
import functools
import warnings
from typing import (
Any,
Callable,
@ -27,11 +26,6 @@ from typing import (
Union,
)
import torch
if torch._running_with_deploy():
raise ImportError("C++ pytree utilities do not work with torch::deploy.")
import optree
from optree import PyTreeSpec # direct import for type annotations
@ -41,9 +35,6 @@ __all__ = [
"Context",
"FlattenFunc",
"UnflattenFunc",
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"TreeSpec",
"LeafSpec",
"register_pytree_node",
@ -77,9 +68,6 @@ TreeSpec = PyTreeSpec
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
UnflattenFunc = Callable[[Iterable, Context], PyTree]
OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@ -96,11 +84,9 @@ def register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
namespace: str = "torch",
) -> None:
"""Register a container-like type as pytree node.
"""Extend the set of types that are considered internal nodes in pytrees.
The ``namespace`` argument is used to avoid collisions that occur when different libraries
register the same Python type with different behaviors. It is recommended to add a unique prefix
@ -123,13 +109,6 @@ def register_pytree_node(
The function should return an instance of ``cls``.
serialized_type_name (str, optional): A keyword argument used to specify the fully
qualified name used when serializing the tree spec.
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
to convert the context of the pytree to a custom json dumpable representation. This is
used for json serialization, which is being used in :mod:`torch.export` right now.
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
how to convert the custom json dumpable representation of the context back to the
original context. This is used for json deserialization, which is being used in
:mod:`torch.export` right now.
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
type registry. This is used to isolate the registry from other modules that might
register a different custom behavior for the same type. (default: :const:`"torch"`)
@ -214,192 +193,24 @@ def register_pytree_node(
)
)
"""
_private_register_pytree_node(
from ._pytree import _register_pytree_node
_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
namespace=namespace,
)
from . import _pytree as python
python._private_register_pytree_node(
optree.register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
)
def _register_pytree_node(
cls: Type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
namespace: str = "torch",
) -> None:
"""Register a container-like type as pytree node for the C++ pytree only.
The ``namespace`` argument is used to avoid collisions that occur when different libraries
register the same Python type with different behaviors. It is recommended to add a unique prefix
to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
the same class in different namespaces for different use cases.
.. warning::
For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
used to isolate the behavior of flattening and unflattening a pytree node type. This is to
prevent accidental collisions between different libraries that may register the same type.
Args:
cls (type): A Python type to treat as an internal pytree node.
flatten_fn (callable): A function to be used during flattening, taking an instance of
``cls`` and returning a pair, with (1) an iterable for the children to be flattened
recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
passed to the ``unflatten_fn``.
unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
The function should return an instance of ``cls``.
serialized_type_name (str, optional): A keyword argument used to specify the fully
qualified name used when serializing the tree spec.
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
to convert the context of the pytree to a custom json dumpable representation. This is
used for json serialization, which is being used in :mod:`torch.export` right now.
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
how to convert the custom json dumpable representation of the context back to the
original context. This is used for json deserialization, which is being used in
:mod:`torch.export` right now.
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
type registry. This is used to isolate the registry from other modules that might
register a different custom behavior for the same type. (default: :const:`"torch"`)
Example::
>>> # xdoctest: +SKIP
>>> # Registry a Python type with lambda functions
>>> register_pytree_node(
... set,
... lambda s: (sorted(s), None, None),
... lambda children, _: set(children),
... namespace='set',
... )
>>> # xdoctest: +SKIP
>>> # Register a Python type into a namespace
>>> import torch
>>> register_pytree_node(
... torch.Tensor,
... flatten_func=lambda tensor: (
... (tensor.cpu().detach().numpy(),),
... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
... ),
... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata),
... namespace='torch2numpy',
... )
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
>>> tree
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # Flatten without specifying the namespace
>>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
>>> # xdoctest: +SKIP
>>> # Flatten with the namespace
>>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP
(
[array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
PyTreeSpec(
{
'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]),
'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*])
},
namespace='torch2numpy'
)
)
>>> # xdoctest: +SKIP
>>> # Register the same type with a different namespace for different behaviors
>>> def tensor2flatparam(tensor):
... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
...
>>> def flatparam2tensor(children, metadata):
... return children[0].reshape(metadata)
...
>>> register_pytree_node(
... torch.Tensor,
... flatten_func=tensor2flatparam,
... unflatten_func=flatparam2tensor,
... namespace='tensor2flatparam',
... )
>>> # xdoctest: +SKIP
>>> # Flatten with the new namespace
>>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP
(
[
Parameter containing: tensor([0., 0.], requires_grad=True),
Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
],
PyTreeSpec(
{
'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
},
namespace='tensor2flatparam'
)
)
"""
warnings.warn(
"torch.utils._cxx_pytree._register_pytree_node is deprecated. "
"Please use torch.utils._cxx_pytree.register_pytree_node instead.",
stacklevel=2,
)
_private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
_reverse_args(unflatten_fn),
namespace=namespace,
)
def _private_register_pytree_node(
cls: Type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
namespace: str = "torch",
) -> None:
"""This is an internal function that is used to register a pytree node type
for the C++ pytree only. End-users should use :func:`register_pytree_node`
instead.
"""
# TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
# PyStructSequence types
if not optree.is_structseq_class(cls):
optree.register_pytree_node(
cls,
flatten_fn,
_reverse_args(unflatten_fn),
namespace=namespace,
)
_register_pytree_node = register_pytree_node
def tree_flatten(