mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"
This reverts commit 4e4a6ad6ecd71a1aefde3992ecf7f77e37d2e264. Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1824099658))
This commit is contained in:
@ -13,7 +13,6 @@ collection support for PyTorch APIs.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -27,11 +26,6 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
if torch._running_with_deploy():
|
||||
raise ImportError("C++ pytree utilities do not work with torch::deploy.")
|
||||
|
||||
import optree
|
||||
from optree import PyTreeSpec # direct import for type annotations
|
||||
|
||||
@ -41,9 +35,6 @@ __all__ = [
|
||||
"Context",
|
||||
"FlattenFunc",
|
||||
"UnflattenFunc",
|
||||
"DumpableContext",
|
||||
"ToDumpableContextFn",
|
||||
"FromDumpableContextFn",
|
||||
"TreeSpec",
|
||||
"LeafSpec",
|
||||
"register_pytree_node",
|
||||
@ -77,9 +68,6 @@ TreeSpec = PyTreeSpec
|
||||
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
|
||||
UnflattenFunc = Callable[[Iterable, Context], PyTree]
|
||||
OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
|
||||
DumpableContext = Any # Any json dumpable text
|
||||
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
||||
|
||||
|
||||
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
|
||||
@ -96,11 +84,9 @@ def register_pytree_node(
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
namespace: str = "torch",
|
||||
) -> None:
|
||||
"""Register a container-like type as pytree node.
|
||||
"""Extend the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
The ``namespace`` argument is used to avoid collisions that occur when different libraries
|
||||
register the same Python type with different behaviors. It is recommended to add a unique prefix
|
||||
@ -123,13 +109,6 @@ def register_pytree_node(
|
||||
The function should return an instance of ``cls``.
|
||||
serialized_type_name (str, optional): A keyword argument used to specify the fully
|
||||
qualified name used when serializing the tree spec.
|
||||
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
|
||||
to convert the context of the pytree to a custom json dumpable representation. This is
|
||||
used for json serialization, which is being used in :mod:`torch.export` right now.
|
||||
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
|
||||
how to convert the custom json dumpable representation of the context back to the
|
||||
original context. This is used for json deserialization, which is being used in
|
||||
:mod:`torch.export` right now.
|
||||
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
||||
type registry. This is used to isolate the registry from other modules that might
|
||||
register a different custom behavior for the same type. (default: :const:`"torch"`)
|
||||
@ -214,192 +193,24 @@ def register_pytree_node(
|
||||
)
|
||||
)
|
||||
"""
|
||||
_private_register_pytree_node(
|
||||
from ._pytree import _register_pytree_node
|
||||
|
||||
_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
from . import _pytree as python
|
||||
|
||||
python._private_register_pytree_node(
|
||||
optree.register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
)
|
||||
|
||||
|
||||
def _register_pytree_node(
|
||||
cls: Type[Any],
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
namespace: str = "torch",
|
||||
) -> None:
|
||||
"""Register a container-like type as pytree node for the C++ pytree only.
|
||||
|
||||
The ``namespace`` argument is used to avoid collisions that occur when different libraries
|
||||
register the same Python type with different behaviors. It is recommended to add a unique prefix
|
||||
to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
|
||||
the same class in different namespaces for different use cases.
|
||||
|
||||
.. warning::
|
||||
For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
|
||||
used to isolate the behavior of flattening and unflattening a pytree node type. This is to
|
||||
prevent accidental collisions between different libraries that may register the same type.
|
||||
|
||||
Args:
|
||||
cls (type): A Python type to treat as an internal pytree node.
|
||||
flatten_fn (callable): A function to be used during flattening, taking an instance of
|
||||
``cls`` and returning a pair, with (1) an iterable for the children to be flattened
|
||||
recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
|
||||
passed to the ``unflatten_fn``.
|
||||
unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
|
||||
returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
|
||||
The function should return an instance of ``cls``.
|
||||
serialized_type_name (str, optional): A keyword argument used to specify the fully
|
||||
qualified name used when serializing the tree spec.
|
||||
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
|
||||
to convert the context of the pytree to a custom json dumpable representation. This is
|
||||
used for json serialization, which is being used in :mod:`torch.export` right now.
|
||||
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
|
||||
how to convert the custom json dumpable representation of the context back to the
|
||||
original context. This is used for json deserialization, which is being used in
|
||||
:mod:`torch.export` right now.
|
||||
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
||||
type registry. This is used to isolate the registry from other modules that might
|
||||
register a different custom behavior for the same type. (default: :const:`"torch"`)
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Registry a Python type with lambda functions
|
||||
>>> register_pytree_node(
|
||||
... set,
|
||||
... lambda s: (sorted(s), None, None),
|
||||
... lambda children, _: set(children),
|
||||
... namespace='set',
|
||||
... )
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Register a Python type into a namespace
|
||||
>>> import torch
|
||||
>>> register_pytree_node(
|
||||
... torch.Tensor,
|
||||
... flatten_func=lambda tensor: (
|
||||
... (tensor.cpu().detach().numpy(),),
|
||||
... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
|
||||
... ),
|
||||
... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata),
|
||||
... namespace='torch2numpy',
|
||||
... )
|
||||
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
|
||||
>>> tree
|
||||
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
|
||||
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||
>>> # Flatten without specifying the namespace
|
||||
>>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP
|
||||
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Flatten with the namespace
|
||||
>>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP
|
||||
(
|
||||
[array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
|
||||
PyTreeSpec(
|
||||
{
|
||||
'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]),
|
||||
'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*])
|
||||
},
|
||||
namespace='torch2numpy'
|
||||
)
|
||||
)
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Register the same type with a different namespace for different behaviors
|
||||
>>> def tensor2flatparam(tensor):
|
||||
... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
|
||||
...
|
||||
>>> def flatparam2tensor(children, metadata):
|
||||
... return children[0].reshape(metadata)
|
||||
...
|
||||
>>> register_pytree_node(
|
||||
... torch.Tensor,
|
||||
... flatten_func=tensor2flatparam,
|
||||
... unflatten_func=flatparam2tensor,
|
||||
... namespace='tensor2flatparam',
|
||||
... )
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Flatten with the new namespace
|
||||
>>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP
|
||||
(
|
||||
[
|
||||
Parameter containing: tensor([0., 0.], requires_grad=True),
|
||||
Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
|
||||
],
|
||||
PyTreeSpec(
|
||||
{
|
||||
'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
|
||||
'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
|
||||
},
|
||||
namespace='tensor2flatparam'
|
||||
)
|
||||
)
|
||||
"""
|
||||
warnings.warn(
|
||||
"torch.utils._cxx_pytree._register_pytree_node is deprecated. "
|
||||
"Please use torch.utils._cxx_pytree.register_pytree_node instead.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
_private_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
_reverse_args(unflatten_fn),
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
def _private_register_pytree_node(
|
||||
cls: Type[Any],
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
namespace: str = "torch",
|
||||
) -> None:
|
||||
"""This is an internal function that is used to register a pytree node type
|
||||
for the C++ pytree only. End-users should use :func:`register_pytree_node`
|
||||
instead.
|
||||
"""
|
||||
# TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
|
||||
# PyStructSequence types
|
||||
if not optree.is_structseq_class(cls):
|
||||
optree.register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
_reverse_args(unflatten_fn),
|
||||
namespace=namespace,
|
||||
)
|
||||
_register_pytree_node = register_pytree_node
|
||||
|
||||
|
||||
def tree_flatten(
|
||||
|
Reference in New Issue
Block a user