mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"
This reverts commit 4e4a6ad6ecd71a1aefde3992ecf7f77e37d2e264. Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1824099658))
This commit is contained in:
@ -623,23 +623,16 @@ class TestExport(TestCase):
|
||||
roundtrip_spec = treespec_loads(treespec_dumps(spec))
|
||||
self.assertEqual(roundtrip_spec, spec)
|
||||
|
||||
@dataclass
|
||||
class MyOtherDataClass: # the pytree registration don't allow registering the same class twice
|
||||
x: int
|
||||
y: int
|
||||
z: int = None
|
||||
|
||||
# Override the registration with keep none fields
|
||||
register_dataclass_as_pytree_node(MyOtherDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyOtherDataClass")
|
||||
register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyDataClass")
|
||||
|
||||
dt = MyOtherDataClass(x=3, y=4)
|
||||
flat, spec = tree_flatten(dt)
|
||||
self.assertEqual(
|
||||
spec,
|
||||
TreeSpec(
|
||||
MyOtherDataClass,
|
||||
MyDataClass,
|
||||
(
|
||||
MyOtherDataClass,
|
||||
MyDataClass,
|
||||
['x', 'y', 'z'],
|
||||
[],
|
||||
),
|
||||
@ -649,7 +642,7 @@ class TestExport(TestCase):
|
||||
self.assertEqual(flat, [3, 4, None])
|
||||
|
||||
orig_dt = tree_unflatten(flat, spec)
|
||||
self.assertTrue(isinstance(orig_dt, MyOtherDataClass))
|
||||
self.assertTrue(isinstance(orig_dt, MyDataClass))
|
||||
self.assertEqual(orig_dt.x, 3)
|
||||
self.assertEqual(orig_dt.y, 4)
|
||||
self.assertEqual(orig_dt.z, None)
|
||||
|
@ -3529,7 +3529,7 @@ class TestFX(JitTestCase):
|
||||
def f_namedtuple_add(x):
|
||||
return x.x + x.y
|
||||
|
||||
pytree.register_pytree_node(
|
||||
pytree._register_pytree_node(
|
||||
Foo,
|
||||
lambda x: ([x.a, x.b], None),
|
||||
lambda x, _: Foo(x[0], x[1]),
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Owner(s): ["module: pytree"]
|
||||
|
||||
import unittest
|
||||
from collections import namedtuple, OrderedDict, UserDict
|
||||
from collections import namedtuple, OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.utils._cxx_pytree as cxx_pytree
|
||||
@ -26,45 +26,6 @@ class GlobalDummyType:
|
||||
|
||||
|
||||
class TestGenericPytree(TestCase):
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_register_pytree_node(self, pytree_impl):
|
||||
class MyDict(UserDict):
|
||||
pass
|
||||
|
||||
d = MyDict(a=1, b=2, c=3)
|
||||
|
||||
# Custom types are leaf nodes by default
|
||||
values, spec = pytree_impl.tree_flatten(d)
|
||||
self.assertEqual(values, [d])
|
||||
self.assertIs(values[0], d)
|
||||
self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
|
||||
self.assertTrue(spec.is_leaf())
|
||||
|
||||
# Register MyDict as a pytree node
|
||||
pytree_impl.register_pytree_node(
|
||||
MyDict,
|
||||
lambda d: (list(d.values()), list(d.keys())),
|
||||
lambda values, keys: MyDict(zip(keys, values)),
|
||||
)
|
||||
|
||||
values, spec = pytree_impl.tree_flatten(d)
|
||||
self.assertEqual(values, [1, 2, 3])
|
||||
self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
|
||||
|
||||
# Do not allow registering the same type twice
|
||||
with self.assertRaisesRegex(ValueError, "already registered"):
|
||||
pytree_impl.register_pytree_node(
|
||||
MyDict,
|
||||
lambda d: (list(d.values()), list(d.keys())),
|
||||
lambda values, keys: MyDict(zip(keys, values)),
|
||||
)
|
||||
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
@ -446,21 +407,6 @@ class TestGenericPytree(TestCase):
|
||||
|
||||
|
||||
class TestPythonPytree(TestCase):
|
||||
def test_deprecated_register_pytree_node(self):
|
||||
class DummyType:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning, "torch.utils._pytree._register_pytree_node"
|
||||
):
|
||||
py_pytree._register_pytree_node(
|
||||
DummyType,
|
||||
lambda dummy: ([dummy.x, dummy.y], None),
|
||||
lambda xs, _: DummyType(*xs),
|
||||
)
|
||||
|
||||
def test_treespec_equality(self):
|
||||
self.assertTrue(
|
||||
py_pytree.LeafSpec() == py_pytree.LeafSpec(),
|
||||
@ -594,7 +540,7 @@ TreeSpec(tuple, None, [*,
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
py_pytree.register_pytree_node(
|
||||
py_pytree._register_pytree_node(
|
||||
DummyType,
|
||||
lambda dummy: ([dummy.x, dummy.y], None),
|
||||
lambda xs, _: DummyType(*xs),
|
||||
@ -614,7 +560,7 @@ TreeSpec(tuple, None, [*,
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
py_pytree.register_pytree_node(
|
||||
py_pytree._register_pytree_node(
|
||||
DummyType,
|
||||
lambda dummy: ([dummy.x, dummy.y], None),
|
||||
lambda xs, _: DummyType(*xs),
|
||||
@ -639,7 +585,7 @@ TreeSpec(tuple, None, [*,
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Both to_dumpable_context and from_dumpable_context"
|
||||
):
|
||||
py_pytree.register_pytree_node(
|
||||
py_pytree._register_pytree_node(
|
||||
DummyType,
|
||||
lambda dummy: ([dummy.x, dummy.y], None),
|
||||
lambda xs, _: DummyType(*xs),
|
||||
@ -653,7 +599,7 @@ TreeSpec(tuple, None, [*,
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
py_pytree.register_pytree_node(
|
||||
py_pytree._register_pytree_node(
|
||||
DummyType,
|
||||
lambda dummy: ([dummy.x, dummy.y], None),
|
||||
lambda xs, _: DummyType(*xs),
|
||||
|
@ -63,16 +63,16 @@ def register_dataclass_as_pytree_node(
|
||||
flatten_fn: Optional[FlattenFunc] = None,
|
||||
unflatten_fn: Optional[UnflattenFunc] = None,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
return_none_fields: bool = False,
|
||||
) -> None:
|
||||
assert dataclasses.is_dataclass(
|
||||
cls
|
||||
), f"Only dataclasses can be registered with this function: {cls}"
|
||||
|
||||
serialized_type = f"{cls.__module__}.{cls.__qualname__}"
|
||||
serialized_type = f"{cls.__module__}.{cls.__name__}"
|
||||
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
|
||||
|
||||
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
|
||||
|
@ -29,7 +29,7 @@ from torch._logging import getArtifactLogger
|
||||
from torch._subclasses import FakeTensor, FakeTensorMode
|
||||
from torch._subclasses.fake_tensor import is_fake
|
||||
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
|
||||
from torch.fx import Interpreter
|
||||
from torch.fx import immutable_collections, Interpreter
|
||||
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq
|
||||
@ -95,6 +95,19 @@ OutputType = Enum(
|
||||
)
|
||||
)
|
||||
|
||||
pytree._register_pytree_node(
|
||||
immutable_collections.immutable_list,
|
||||
lambda x: (list(x), None),
|
||||
lambda x, c: immutable_collections.immutable_list(x),
|
||||
)
|
||||
pytree._register_pytree_node(
|
||||
immutable_collections.immutable_dict,
|
||||
lambda x: (list(x.values()), list(x.keys())),
|
||||
lambda x, c: immutable_collections.immutable_dict(
|
||||
dict(zip(c, x))
|
||||
),
|
||||
)
|
||||
|
||||
def partial_asdict(obj: Any) -> Any:
|
||||
if dataclasses.is_dataclass(obj):
|
||||
return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}
|
||||
|
@ -49,7 +49,7 @@ CONSTANT_NUMEL_LIMIT = 1
|
||||
|
||||
# We currently convert all SymInt to proxies before we use them.
|
||||
# This could plausibly be handled at the Dynamo level.
|
||||
pytree.register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs))
|
||||
pytree._register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs))
|
||||
|
||||
def fake_signature(fn, nargs):
|
||||
"""FX gets confused by varargs, de-confuse it"""
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, Iterable, List, Tuple
|
||||
|
||||
from ._compatibility import compatibility
|
||||
from torch.utils._pytree import Context, register_pytree_node
|
||||
from torch.utils._pytree import Context, _register_pytree_node
|
||||
|
||||
__all__ = ["immutable_list", "immutable_dict"]
|
||||
|
||||
@ -50,5 +50,5 @@ def _immutable_list_unflatten(values: Iterable[Any], context: Context) -> List[A
|
||||
return immutable_list(values)
|
||||
|
||||
|
||||
register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten)
|
||||
register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)
|
||||
_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten)
|
||||
_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)
|
||||
|
@ -40,11 +40,7 @@ class _PyTreeExtensionContext:
|
||||
|
||||
def __enter__(self):
|
||||
for class_type, (flatten_func, unflatten_func) in self._extensions.items():
|
||||
pytree._private_register_pytree_node(
|
||||
class_type,
|
||||
flatten_func,
|
||||
unflatten_func,
|
||||
)
|
||||
pytree._register_pytree_node(class_type, flatten_func, unflatten_func)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
@ -97,11 +93,8 @@ class _PyTreeExtensionContext:
|
||||
# All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
|
||||
named_model_output_classes = inspect.getmembers(
|
||||
modeling_outputs,
|
||||
lambda x: (
|
||||
inspect.isclass(x)
|
||||
and issubclass(x, modeling_outputs.ModelOutput)
|
||||
and x is not modeling_outputs.ModelOutput
|
||||
),
|
||||
lambda x: inspect.isclass(x)
|
||||
and issubclass(x, modeling_outputs.ModelOutput),
|
||||
)
|
||||
|
||||
for _, class_type in named_model_output_classes:
|
||||
|
@ -13,7 +13,7 @@ def pytree_register_structseq(cls):
|
||||
def structseq_unflatten(values, context):
|
||||
return cls(values)
|
||||
|
||||
torch.utils._pytree.register_pytree_node(cls, structseq_flatten, structseq_unflatten)
|
||||
torch.utils._pytree._register_pytree_node(cls, structseq_flatten, structseq_unflatten)
|
||||
|
||||
for name in dir(return_types):
|
||||
if name.startswith('__'):
|
||||
|
@ -13,7 +13,6 @@ collection support for PyTorch APIs.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -27,11 +26,6 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
if torch._running_with_deploy():
|
||||
raise ImportError("C++ pytree utilities do not work with torch::deploy.")
|
||||
|
||||
import optree
|
||||
from optree import PyTreeSpec # direct import for type annotations
|
||||
|
||||
@ -41,9 +35,6 @@ __all__ = [
|
||||
"Context",
|
||||
"FlattenFunc",
|
||||
"UnflattenFunc",
|
||||
"DumpableContext",
|
||||
"ToDumpableContextFn",
|
||||
"FromDumpableContextFn",
|
||||
"TreeSpec",
|
||||
"LeafSpec",
|
||||
"register_pytree_node",
|
||||
@ -77,9 +68,6 @@ TreeSpec = PyTreeSpec
|
||||
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
|
||||
UnflattenFunc = Callable[[Iterable, Context], PyTree]
|
||||
OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
|
||||
DumpableContext = Any # Any json dumpable text
|
||||
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
||||
|
||||
|
||||
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
|
||||
@ -96,11 +84,9 @@ def register_pytree_node(
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
namespace: str = "torch",
|
||||
) -> None:
|
||||
"""Register a container-like type as pytree node.
|
||||
"""Extend the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
The ``namespace`` argument is used to avoid collisions that occur when different libraries
|
||||
register the same Python type with different behaviors. It is recommended to add a unique prefix
|
||||
@ -123,13 +109,6 @@ def register_pytree_node(
|
||||
The function should return an instance of ``cls``.
|
||||
serialized_type_name (str, optional): A keyword argument used to specify the fully
|
||||
qualified name used when serializing the tree spec.
|
||||
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
|
||||
to convert the context of the pytree to a custom json dumpable representation. This is
|
||||
used for json serialization, which is being used in :mod:`torch.export` right now.
|
||||
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
|
||||
how to convert the custom json dumpable representation of the context back to the
|
||||
original context. This is used for json deserialization, which is being used in
|
||||
:mod:`torch.export` right now.
|
||||
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
||||
type registry. This is used to isolate the registry from other modules that might
|
||||
register a different custom behavior for the same type. (default: :const:`"torch"`)
|
||||
@ -214,192 +193,24 @@ def register_pytree_node(
|
||||
)
|
||||
)
|
||||
"""
|
||||
_private_register_pytree_node(
|
||||
from ._pytree import _register_pytree_node
|
||||
|
||||
_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
from . import _pytree as python
|
||||
|
||||
python._private_register_pytree_node(
|
||||
optree.register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
)
|
||||
|
||||
|
||||
def _register_pytree_node(
|
||||
cls: Type[Any],
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
namespace: str = "torch",
|
||||
) -> None:
|
||||
"""Register a container-like type as pytree node for the C++ pytree only.
|
||||
|
||||
The ``namespace`` argument is used to avoid collisions that occur when different libraries
|
||||
register the same Python type with different behaviors. It is recommended to add a unique prefix
|
||||
to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
|
||||
the same class in different namespaces for different use cases.
|
||||
|
||||
.. warning::
|
||||
For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
|
||||
used to isolate the behavior of flattening and unflattening a pytree node type. This is to
|
||||
prevent accidental collisions between different libraries that may register the same type.
|
||||
|
||||
Args:
|
||||
cls (type): A Python type to treat as an internal pytree node.
|
||||
flatten_fn (callable): A function to be used during flattening, taking an instance of
|
||||
``cls`` and returning a pair, with (1) an iterable for the children to be flattened
|
||||
recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
|
||||
passed to the ``unflatten_fn``.
|
||||
unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
|
||||
returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
|
||||
The function should return an instance of ``cls``.
|
||||
serialized_type_name (str, optional): A keyword argument used to specify the fully
|
||||
qualified name used when serializing the tree spec.
|
||||
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
|
||||
to convert the context of the pytree to a custom json dumpable representation. This is
|
||||
used for json serialization, which is being used in :mod:`torch.export` right now.
|
||||
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
|
||||
how to convert the custom json dumpable representation of the context back to the
|
||||
original context. This is used for json deserialization, which is being used in
|
||||
:mod:`torch.export` right now.
|
||||
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
||||
type registry. This is used to isolate the registry from other modules that might
|
||||
register a different custom behavior for the same type. (default: :const:`"torch"`)
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Registry a Python type with lambda functions
|
||||
>>> register_pytree_node(
|
||||
... set,
|
||||
... lambda s: (sorted(s), None, None),
|
||||
... lambda children, _: set(children),
|
||||
... namespace='set',
|
||||
... )
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Register a Python type into a namespace
|
||||
>>> import torch
|
||||
>>> register_pytree_node(
|
||||
... torch.Tensor,
|
||||
... flatten_func=lambda tensor: (
|
||||
... (tensor.cpu().detach().numpy(),),
|
||||
... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
|
||||
... ),
|
||||
... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata),
|
||||
... namespace='torch2numpy',
|
||||
... )
|
||||
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
|
||||
>>> tree
|
||||
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
|
||||
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||
>>> # Flatten without specifying the namespace
|
||||
>>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP
|
||||
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Flatten with the namespace
|
||||
>>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP
|
||||
(
|
||||
[array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
|
||||
PyTreeSpec(
|
||||
{
|
||||
'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]),
|
||||
'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*])
|
||||
},
|
||||
namespace='torch2numpy'
|
||||
)
|
||||
)
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Register the same type with a different namespace for different behaviors
|
||||
>>> def tensor2flatparam(tensor):
|
||||
... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
|
||||
...
|
||||
>>> def flatparam2tensor(children, metadata):
|
||||
... return children[0].reshape(metadata)
|
||||
...
|
||||
>>> register_pytree_node(
|
||||
... torch.Tensor,
|
||||
... flatten_func=tensor2flatparam,
|
||||
... unflatten_func=flatparam2tensor,
|
||||
... namespace='tensor2flatparam',
|
||||
... )
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Flatten with the new namespace
|
||||
>>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP
|
||||
(
|
||||
[
|
||||
Parameter containing: tensor([0., 0.], requires_grad=True),
|
||||
Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
|
||||
],
|
||||
PyTreeSpec(
|
||||
{
|
||||
'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
|
||||
'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
|
||||
},
|
||||
namespace='tensor2flatparam'
|
||||
)
|
||||
)
|
||||
"""
|
||||
warnings.warn(
|
||||
"torch.utils._cxx_pytree._register_pytree_node is deprecated. "
|
||||
"Please use torch.utils._cxx_pytree.register_pytree_node instead.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
_private_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
_reverse_args(unflatten_fn),
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
def _private_register_pytree_node(
|
||||
cls: Type[Any],
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
namespace: str = "torch",
|
||||
) -> None:
|
||||
"""This is an internal function that is used to register a pytree node type
|
||||
for the C++ pytree only. End-users should use :func:`register_pytree_node`
|
||||
instead.
|
||||
"""
|
||||
# TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
|
||||
# PyStructSequence types
|
||||
if not optree.is_structseq_class(cls):
|
||||
optree.register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
_reverse_args(unflatten_fn),
|
||||
namespace=namespace,
|
||||
)
|
||||
_register_pytree_node = register_pytree_node
|
||||
|
||||
|
||||
def tree_flatten(
|
||||
|
@ -17,7 +17,6 @@ To improve the performance we can move parts of the implementation to C++.
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import threading
|
||||
import warnings
|
||||
from collections import deque, namedtuple, OrderedDict
|
||||
from typing import (
|
||||
@ -100,7 +99,6 @@ class NodeDef(NamedTuple):
|
||||
unflatten_fn: UnflattenFunc
|
||||
|
||||
|
||||
_NODE_REGISTRY_LOCK = threading.Lock()
|
||||
SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}
|
||||
|
||||
|
||||
@ -122,59 +120,6 @@ SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {}
|
||||
SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {}
|
||||
|
||||
|
||||
def register_pytree_node(
|
||||
cls: Any,
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
) -> None:
|
||||
"""Register a container-like type as pytree node.
|
||||
|
||||
Args:
|
||||
cls: the type to register
|
||||
flatten_fn: A callable that takes a pytree and returns a flattened
|
||||
representation of the pytree and additional context to represent the
|
||||
flattened pytree.
|
||||
unflatten_fn: A callable that takes a flattened version of the pytree,
|
||||
additional context, and returns an unflattened pytree.
|
||||
serialized_type_name: A keyword argument used to specify the fully qualified
|
||||
name used when serializing the tree spec.
|
||||
to_dumpable_context: An optional keyword argument to custom specify how
|
||||
to convert the context of the pytree to a custom json dumpable
|
||||
representation. This is used for json serialization, which is being
|
||||
used in torch.export right now.
|
||||
from_dumpable_context: An optional keyword argument to custom specify how
|
||||
to convert the custom json dumpable representation of the context
|
||||
back to the original context. This is used for json deserialization,
|
||||
which is being used in torch.export right now.
|
||||
"""
|
||||
_private_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
)
|
||||
|
||||
try:
|
||||
from . import _cxx_pytree as cxx
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
cxx._private_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
)
|
||||
|
||||
|
||||
def _register_pytree_node(
|
||||
cls: Any,
|
||||
flatten_fn: FlattenFunc,
|
||||
@ -186,8 +131,7 @@ def _register_pytree_node(
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
) -> None:
|
||||
"""Register a container-like type as pytree node for the Python pytree only.
|
||||
|
||||
"""
|
||||
Args:
|
||||
cls: the type to register
|
||||
flatten_fn: A callable that takes a pytree and returns a flattened
|
||||
@ -206,69 +150,39 @@ def _register_pytree_node(
|
||||
back to the original context. This is used for json deserialization,
|
||||
which is being used in torch.export right now.
|
||||
"""
|
||||
warnings.warn(
|
||||
"torch.utils._pytree._register_pytree_node is deprecated. "
|
||||
"Please use torch.utils._pytree.register_pytree_node instead.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if to_str_fn is not None or maybe_from_str_fn is not None:
|
||||
warnings.warn(
|
||||
"to_str_fn and maybe_from_str_fn is deprecated. "
|
||||
"Please use to_dumpable_context and from_dumpable_context instead."
|
||||
)
|
||||
|
||||
_private_register_pytree_node(
|
||||
node_def = NodeDef(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
)
|
||||
SUPPORTED_NODES[cls] = node_def
|
||||
|
||||
|
||||
def _private_register_pytree_node(
|
||||
cls: Any,
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
) -> None:
|
||||
"""This is an internal function that is used to register a pytree node type
|
||||
for the Python pytree only. End-users should use :func:`register_pytree_node`
|
||||
instead.
|
||||
"""
|
||||
with _NODE_REGISTRY_LOCK:
|
||||
if cls in SUPPORTED_NODES:
|
||||
raise ValueError(f"{cls} is already registered as pytree node.")
|
||||
|
||||
node_def = NodeDef(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
|
||||
raise ValueError(
|
||||
f"Both to_dumpable_context and from_dumpable_context for {cls} must "
|
||||
"be None or registered."
|
||||
)
|
||||
SUPPORTED_NODES[cls] = node_def
|
||||
|
||||
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
|
||||
raise ValueError(
|
||||
f"Both to_dumpable_context and from_dumpable_context for {cls} must "
|
||||
"be None or registered."
|
||||
)
|
||||
if serialized_type_name is None:
|
||||
serialized_type_name = f"{cls.__module__}.{cls.__name__}"
|
||||
|
||||
if serialized_type_name is None:
|
||||
serialized_type_name = f"{cls.__module__}.{cls.__qualname__}"
|
||||
serialize_node_def = _SerializeNodeDef(
|
||||
cls,
|
||||
serialized_type_name,
|
||||
to_dumpable_context,
|
||||
from_dumpable_context,
|
||||
)
|
||||
SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
|
||||
SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
|
||||
|
||||
serialize_node_def = _SerializeNodeDef(
|
||||
cls,
|
||||
serialized_type_name,
|
||||
to_dumpable_context,
|
||||
from_dumpable_context,
|
||||
)
|
||||
SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
|
||||
SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
|
||||
|
||||
register_pytree_node = _register_pytree_node
|
||||
|
||||
|
||||
def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
|
||||
@ -329,25 +243,25 @@ def _odict_unflatten(
|
||||
return OrderedDict((key, value) for key, value in zip(context, values))
|
||||
|
||||
|
||||
_private_register_pytree_node(
|
||||
_register_pytree_node(
|
||||
dict,
|
||||
_dict_flatten,
|
||||
_dict_unflatten,
|
||||
serialized_type_name="builtins.dict",
|
||||
)
|
||||
_private_register_pytree_node(
|
||||
_register_pytree_node(
|
||||
list,
|
||||
_list_flatten,
|
||||
_list_unflatten,
|
||||
serialized_type_name="builtins.list",
|
||||
)
|
||||
_private_register_pytree_node(
|
||||
_register_pytree_node(
|
||||
tuple,
|
||||
_tuple_flatten,
|
||||
_tuple_unflatten,
|
||||
serialized_type_name="builtins.tuple",
|
||||
)
|
||||
_private_register_pytree_node(
|
||||
_register_pytree_node(
|
||||
namedtuple,
|
||||
_namedtuple_flatten,
|
||||
_namedtuple_unflatten,
|
||||
@ -355,7 +269,7 @@ _private_register_pytree_node(
|
||||
from_dumpable_context=_namedtuple_deserialize,
|
||||
serialized_type_name="collections.namedtuple",
|
||||
)
|
||||
_private_register_pytree_node(
|
||||
_register_pytree_node(
|
||||
OrderedDict,
|
||||
_odict_flatten,
|
||||
_odict_unflatten,
|
||||
@ -815,7 +729,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
|
||||
|
||||
if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
|
||||
raise NotImplementedError(
|
||||
f"Serializing {treespec.type} in pytree is not registered.",
|
||||
f"Serializing {treespec.type} in pytree is not registered."
|
||||
)
|
||||
|
||||
serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]
|
||||
|
Reference in New Issue
Block a user