[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)

Changes:

1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2023-11-28 15:27:07 +08:00
committed by PyTorch MergeBot
parent 088fc7779e
commit 89a1fe6966
11 changed files with 415 additions and 70 deletions

View File

@ -624,16 +624,23 @@ class TestExport(TestCase):
roundtrip_spec = treespec_loads(treespec_dumps(spec))
self.assertEqual(roundtrip_spec, spec)
# Override the registration with keep none fields
register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyDataClass")
@dataclass
class MyOtherDataClass: # the pytree registration don't allow registering the same class twice
x: int
y: int
z: int = None
# Override the registration with keep none fields
register_dataclass_as_pytree_node(MyOtherDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyOtherDataClass")
dt = MyOtherDataClass(x=3, y=4)
flat, spec = tree_flatten(dt)
self.assertEqual(
spec,
TreeSpec(
MyDataClass,
MyOtherDataClass,
(
MyDataClass,
MyOtherDataClass,
['x', 'y', 'z'],
[],
),
@ -643,7 +650,7 @@ class TestExport(TestCase):
self.assertEqual(flat, [3, 4, None])
orig_dt = tree_unflatten(flat, spec)
self.assertTrue(isinstance(orig_dt, MyDataClass))
self.assertTrue(isinstance(orig_dt, MyOtherDataClass))
self.assertEqual(orig_dt.x, 3)
self.assertEqual(orig_dt.y, 4)
self.assertEqual(orig_dt.z, None)

View File

@ -3529,7 +3529,7 @@ class TestFX(JitTestCase):
def f_namedtuple_add(x):
return x.x + x.y
pytree._register_pytree_node(
pytree.register_pytree_node(
Foo,
lambda x: ([x.a, x.b], None),
lambda x, _: Foo(x[0], x[1]),

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: pytree"]
import unittest
from collections import namedtuple, OrderedDict
from collections import namedtuple, OrderedDict, UserDict
import torch
import torch.utils._cxx_pytree as cxx_pytree
@ -26,6 +26,45 @@ class GlobalDummyType:
class TestGenericPytree(TestCase):
@parametrize(
"pytree_impl",
[
subtest(py_pytree, name="py"),
subtest(cxx_pytree, name="cxx"),
],
)
def test_register_pytree_node(self, pytree_impl):
class MyDict(UserDict):
pass
d = MyDict(a=1, b=2, c=3)
# Custom types are leaf nodes by default
values, spec = pytree_impl.tree_flatten(d)
self.assertEqual(values, [d])
self.assertIs(values[0], d)
self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
self.assertTrue(spec.is_leaf())
# Register MyDict as a pytree node
pytree_impl.register_pytree_node(
MyDict,
lambda d: (list(d.values()), list(d.keys())),
lambda values, keys: MyDict(zip(keys, values)),
)
values, spec = pytree_impl.tree_flatten(d)
self.assertEqual(values, [1, 2, 3])
self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
# Do not allow registering the same type twice
with self.assertRaisesRegex(ValueError, "already registered"):
pytree_impl.register_pytree_node(
MyDict,
lambda d: (list(d.values()), list(d.keys())),
lambda values, keys: MyDict(zip(keys, values)),
)
@parametrize(
"pytree_impl",
[
@ -407,6 +446,28 @@ class TestGenericPytree(TestCase):
class TestPythonPytree(TestCase):
def test_deprecated_register_pytree_node(self):
class DummyType:
def __init__(self, x, y):
self.x = x
self.y = y
with self.assertWarnsRegex(
UserWarning, "torch.utils._pytree._register_pytree_node"
):
py_pytree._register_pytree_node(
DummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: DummyType(*xs),
)
with self.assertWarnsRegex(UserWarning, "already registered"):
py_pytree._register_pytree_node(
DummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: DummyType(*xs),
)
def test_treespec_equality(self):
self.assertTrue(
py_pytree.LeafSpec() == py_pytree.LeafSpec(),
@ -540,7 +601,7 @@ TreeSpec(tuple, None, [*,
self.x = x
self.y = y
py_pytree._register_pytree_node(
py_pytree.register_pytree_node(
DummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: DummyType(*xs),
@ -560,7 +621,7 @@ TreeSpec(tuple, None, [*,
self.x = x
self.y = y
py_pytree._register_pytree_node(
py_pytree.register_pytree_node(
DummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: DummyType(*xs),
@ -585,7 +646,7 @@ TreeSpec(tuple, None, [*,
with self.assertRaisesRegex(
ValueError, "Both to_dumpable_context and from_dumpable_context"
):
py_pytree._register_pytree_node(
py_pytree.register_pytree_node(
DummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: DummyType(*xs),
@ -599,7 +660,7 @@ TreeSpec(tuple, None, [*,
self.x = x
self.y = y
py_pytree._register_pytree_node(
py_pytree.register_pytree_node(
DummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: DummyType(*xs),

View File

@ -63,16 +63,16 @@ def register_dataclass_as_pytree_node(
flatten_fn: Optional[FlattenFunc] = None,
unflatten_fn: Optional[UnflattenFunc] = None,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
serialized_type_name: Optional[str] = None,
return_none_fields: bool = False,
) -> None:
assert dataclasses.is_dataclass(
cls
), f"Only dataclasses can be registered with this function: {cls}"
serialized_type = f"{cls.__module__}.{cls.__name__}"
serialized_type = f"{cls.__module__}.{cls.__qualname__}"
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:

View File

@ -29,7 +29,7 @@ from torch._logging import getArtifactLogger
from torch._subclasses import FakeTensor, FakeTensorMode
from torch._subclasses.fake_tensor import is_fake
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch.fx import immutable_collections, Interpreter
from torch.fx import Interpreter
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.symbolic_shapes import (
ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq
@ -95,19 +95,6 @@ OutputType = Enum(
)
)
pytree._register_pytree_node(
immutable_collections.immutable_list,
lambda x: (list(x), None),
lambda x, c: immutable_collections.immutable_list(x),
)
pytree._register_pytree_node(
immutable_collections.immutable_dict,
lambda x: (list(x.values()), list(x.keys())),
lambda x, c: immutable_collections.immutable_dict(
dict(zip(c, x))
),
)
def partial_asdict(obj: Any) -> Any:
if dataclasses.is_dataclass(obj):
return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}

View File

@ -49,7 +49,7 @@ CONSTANT_NUMEL_LIMIT = 1
# We currently convert all SymInt to proxies before we use them.
# This could plausibly be handled at the Dynamo level.
pytree._register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs))
pytree.register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs))
def fake_signature(fn, nargs):
"""FX gets confused by varargs, de-confuse it"""

View File

@ -1,7 +1,7 @@
from typing import Any, Dict, Iterable, List, Tuple
from ._compatibility import compatibility
from torch.utils._pytree import Context, _register_pytree_node
from torch.utils._pytree import Context, register_pytree_node
__all__ = ["immutable_list", "immutable_dict"]
@ -50,5 +50,5 @@ def _immutable_list_unflatten(values: Iterable[Any], context: Context) -> List[A
return immutable_list(values)
_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten)
_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)
register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten)
register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)

View File

@ -40,7 +40,11 @@ class _PyTreeExtensionContext:
def __enter__(self):
for class_type, (flatten_func, unflatten_func) in self._extensions.items():
pytree._register_pytree_node(class_type, flatten_func, unflatten_func)
pytree._private_register_pytree_node(
class_type,
flatten_func,
unflatten_func,
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
@ -93,8 +97,11 @@ class _PyTreeExtensionContext:
# All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
named_model_output_classes = inspect.getmembers(
modeling_outputs,
lambda x: inspect.isclass(x)
and issubclass(x, modeling_outputs.ModelOutput),
lambda x: (
inspect.isclass(x)
and issubclass(x, modeling_outputs.ModelOutput)
and x is not modeling_outputs.ModelOutput
),
)
for _, class_type in named_model_output_classes:

View File

@ -13,7 +13,7 @@ def pytree_register_structseq(cls):
def structseq_unflatten(values, context):
return cls(values)
torch.utils._pytree._register_pytree_node(cls, structseq_flatten, structseq_unflatten)
torch.utils._pytree.register_pytree_node(cls, structseq_flatten, structseq_unflatten)
for name in dir(return_types):
if name.startswith('__'):

View File

@ -13,6 +13,7 @@ collection support for PyTorch APIs.
"""
import functools
import warnings
from typing import (
Any,
Callable,
@ -26,6 +27,11 @@ from typing import (
Union,
)
import torch
if torch._running_with_deploy():
raise ImportError("C++ pytree utilities do not work with torch::deploy.")
import optree
from optree import PyTreeSpec # direct import for type annotations
@ -35,6 +41,9 @@ __all__ = [
"Context",
"FlattenFunc",
"UnflattenFunc",
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"TreeSpec",
"LeafSpec",
"register_pytree_node",
@ -68,6 +77,9 @@ TreeSpec = PyTreeSpec
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@ -84,9 +96,11 @@ def register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
namespace: str = "torch",
) -> None:
"""Extend the set of types that are considered internal nodes in pytrees.
"""Register a container-like type as pytree node.
The ``namespace`` argument is used to avoid collisions that occur when different libraries
register the same Python type with different behaviors. It is recommended to add a unique prefix
@ -109,6 +123,13 @@ def register_pytree_node(
The function should return an instance of ``cls``.
serialized_type_name (str, optional): A keyword argument used to specify the fully
qualified name used when serializing the tree spec.
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
to convert the context of the pytree to a custom json dumpable representation. This is
used for json serialization, which is being used in :mod:`torch.export` right now.
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
how to convert the custom json dumpable representation of the context back to the
original context. This is used for json deserialization, which is being used in
:mod:`torch.export` right now.
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
type registry. This is used to isolate the registry from other modules that might
register a different custom behavior for the same type. (default: :const:`"torch"`)
@ -193,24 +214,192 @@ def register_pytree_node(
)
)
"""
from ._pytree import _register_pytree_node
_register_pytree_node(
_private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
namespace=namespace,
)
optree.register_pytree_node(
from . import _pytree as python
python._private_register_pytree_node(
cls,
flatten_fn,
_reverse_args(unflatten_fn),
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
)
def _register_pytree_node(
cls: Type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
namespace: str = "torch",
) -> None:
"""Register a container-like type as pytree node for the C++ pytree only.
The ``namespace`` argument is used to avoid collisions that occur when different libraries
register the same Python type with different behaviors. It is recommended to add a unique prefix
to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
the same class in different namespaces for different use cases.
.. warning::
For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
used to isolate the behavior of flattening and unflattening a pytree node type. This is to
prevent accidental collisions between different libraries that may register the same type.
Args:
cls (type): A Python type to treat as an internal pytree node.
flatten_fn (callable): A function to be used during flattening, taking an instance of
``cls`` and returning a pair, with (1) an iterable for the children to be flattened
recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
passed to the ``unflatten_fn``.
unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
The function should return an instance of ``cls``.
serialized_type_name (str, optional): A keyword argument used to specify the fully
qualified name used when serializing the tree spec.
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
to convert the context of the pytree to a custom json dumpable representation. This is
used for json serialization, which is being used in :mod:`torch.export` right now.
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
how to convert the custom json dumpable representation of the context back to the
original context. This is used for json deserialization, which is being used in
:mod:`torch.export` right now.
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
type registry. This is used to isolate the registry from other modules that might
register a different custom behavior for the same type. (default: :const:`"torch"`)
Example::
>>> # xdoctest: +SKIP
>>> # Registry a Python type with lambda functions
>>> register_pytree_node(
... set,
... lambda s: (sorted(s), None, None),
... lambda children, _: set(children),
... namespace='set',
... )
>>> # xdoctest: +SKIP
>>> # Register a Python type into a namespace
>>> import torch
>>> register_pytree_node(
... torch.Tensor,
... flatten_func=lambda tensor: (
... (tensor.cpu().detach().numpy(),),
... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
... ),
... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata),
... namespace='torch2numpy',
... )
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
>>> tree
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # Flatten without specifying the namespace
>>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
>>> # xdoctest: +SKIP
>>> # Flatten with the namespace
>>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP
(
[array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
PyTreeSpec(
{
'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]),
'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*])
},
namespace='torch2numpy'
)
)
>>> # xdoctest: +SKIP
>>> # Register the same type with a different namespace for different behaviors
>>> def tensor2flatparam(tensor):
... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
...
>>> def flatparam2tensor(children, metadata):
... return children[0].reshape(metadata)
...
>>> register_pytree_node(
... torch.Tensor,
... flatten_func=tensor2flatparam,
... unflatten_func=flatparam2tensor,
... namespace='tensor2flatparam',
... )
>>> # xdoctest: +SKIP
>>> # Flatten with the new namespace
>>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP
(
[
Parameter containing: tensor([0., 0.], requires_grad=True),
Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
],
PyTreeSpec(
{
'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
},
namespace='tensor2flatparam'
)
)
"""
warnings.warn(
"torch.utils._cxx_pytree._register_pytree_node is deprecated. "
"Please use torch.utils._cxx_pytree.register_pytree_node instead.",
stacklevel=2,
)
_private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
namespace=namespace,
)
_register_pytree_node = register_pytree_node
def _private_register_pytree_node(
cls: Type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
namespace: str = "torch",
) -> None:
"""This is an internal function that is used to register a pytree node type
for the C++ pytree only. End-users should use :func:`register_pytree_node`
instead.
"""
# TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
# PyStructSequence types
if not optree.is_structseq_class(cls):
optree.register_pytree_node(
cls,
flatten_fn,
_reverse_args(unflatten_fn),
namespace=namespace,
)
def tree_flatten(

View File

@ -17,6 +17,7 @@ To improve the performance we can move parts of the implementation to C++.
import dataclasses
import json
import threading
import warnings
from collections import deque, namedtuple, OrderedDict
from typing import (
@ -99,6 +100,7 @@ class NodeDef(NamedTuple):
unflatten_fn: UnflattenFunc
_NODE_REGISTRY_LOCK = threading.Lock()
SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}
@ -120,18 +122,17 @@ SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {}
SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {}
def _register_pytree_node(
def register_pytree_node(
cls: Any,
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
to_str_fn: Optional[ToStrFunc] = None, # deprecated
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
) -> None:
"""
"""Register a container-like type as pytree node.
Args:
cls: the type to register
flatten_fn: A callable that takes a pytree and returns a flattened
@ -150,39 +151,132 @@ def _register_pytree_node(
back to the original context. This is used for json deserialization,
which is being used in torch.export right now.
"""
with _NODE_REGISTRY_LOCK:
if cls in SUPPORTED_NODES:
raise ValueError(f"{cls} is already registered as pytree node.")
_private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
)
try:
from . import _cxx_pytree as cxx
except ImportError:
pass
else:
cxx._private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
)
def _register_pytree_node(
cls: Any,
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
to_str_fn: Optional[ToStrFunc] = None, # deprecated
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
) -> None:
"""Register a container-like type as pytree node for the Python pytree only.
Args:
cls: the type to register
flatten_fn: A callable that takes a pytree and returns a flattened
representation of the pytree and additional context to represent the
flattened pytree.
unflatten_fn: A callable that takes a flattened version of the pytree,
additional context, and returns an unflattened pytree.
serialized_type_name: A keyword argument used to specify the fully qualified
name used when serializing the tree spec.
to_dumpable_context: An optional keyword argument to custom specify how
to convert the context of the pytree to a custom json dumpable
representation. This is used for json serialization, which is being
used in torch.export right now.
from_dumpable_context: An optional keyword argument to custom specify how
to convert the custom json dumpable representation of the context
back to the original context. This is used for json deserialization,
which is being used in torch.export right now.
"""
warnings.warn(
"torch.utils._pytree._register_pytree_node is deprecated. "
"Please use torch.utils._pytree.register_pytree_node instead.",
stacklevel=2,
)
if to_str_fn is not None or maybe_from_str_fn is not None:
warnings.warn(
"to_str_fn and maybe_from_str_fn is deprecated. "
"Please use to_dumpable_context and from_dumpable_context instead."
)
node_def = NodeDef(
_private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
)
SUPPORTED_NODES[cls] = node_def
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
raise ValueError(
f"Both to_dumpable_context and from_dumpable_context for {cls} must "
"be None or registered."
def _private_register_pytree_node(
cls: Any,
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
) -> None:
"""This is an internal function that is used to register a pytree node type
for the Python pytree only. End-users should use :func:`register_pytree_node`
instead.
"""
with _NODE_REGISTRY_LOCK:
if cls in SUPPORTED_NODES:
# TODO: change this warning to an error after OSS/internal stabilize
warnings.warn(
f"{cls} is already registered as pytree node. "
"Overwriting the previous registration.",
)
node_def = NodeDef(
cls,
flatten_fn,
unflatten_fn,
)
SUPPORTED_NODES[cls] = node_def
if serialized_type_name is None:
serialized_type_name = f"{cls.__module__}.{cls.__name__}"
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
raise ValueError(
f"Both to_dumpable_context and from_dumpable_context for {cls} must "
"be None or registered."
)
serialize_node_def = _SerializeNodeDef(
cls,
serialized_type_name,
to_dumpable_context,
from_dumpable_context,
)
SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
if serialized_type_name is None:
serialized_type_name = f"{cls.__module__}.{cls.__qualname__}"
register_pytree_node = _register_pytree_node
serialize_node_def = _SerializeNodeDef(
cls,
serialized_type_name,
to_dumpable_context,
from_dumpable_context,
)
SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
@ -243,25 +337,25 @@ def _odict_unflatten(
return OrderedDict((key, value) for key, value in zip(context, values))
_register_pytree_node(
_private_register_pytree_node(
dict,
_dict_flatten,
_dict_unflatten,
serialized_type_name="builtins.dict",
)
_register_pytree_node(
_private_register_pytree_node(
list,
_list_flatten,
_list_unflatten,
serialized_type_name="builtins.list",
)
_register_pytree_node(
_private_register_pytree_node(
tuple,
_tuple_flatten,
_tuple_unflatten,
serialized_type_name="builtins.tuple",
)
_register_pytree_node(
_private_register_pytree_node(
namedtuple,
_namedtuple_flatten,
_namedtuple_unflatten,
@ -269,7 +363,7 @@ _register_pytree_node(
from_dumpable_context=_namedtuple_deserialize,
serialized_type_name="collections.namedtuple",
)
_register_pytree_node(
_private_register_pytree_node(
OrderedDict,
_odict_flatten,
_odict_unflatten,
@ -729,7 +823,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
raise NotImplementedError(
f"Serializing {treespec.type} in pytree is not registered."
f"Serializing {treespec.type} in pytree is not registered.",
)
serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]