mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Changes: 1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree. 2. Do not allow registering a type as pytree node twice in the Python pytree. 3. Add thread lock to the Python pytree node register API. 4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning. 5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations. 6. Add tests to ensure a warning will be raised when the old private function is called. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
088fc7779e
commit
89a1fe6966
@ -624,16 +624,23 @@ class TestExport(TestCase):
|
|||||||
roundtrip_spec = treespec_loads(treespec_dumps(spec))
|
roundtrip_spec = treespec_loads(treespec_dumps(spec))
|
||||||
self.assertEqual(roundtrip_spec, spec)
|
self.assertEqual(roundtrip_spec, spec)
|
||||||
|
|
||||||
# Override the registration with keep none fields
|
@dataclass
|
||||||
register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyDataClass")
|
class MyOtherDataClass: # the pytree registration don't allow registering the same class twice
|
||||||
|
x: int
|
||||||
|
y: int
|
||||||
|
z: int = None
|
||||||
|
|
||||||
|
# Override the registration with keep none fields
|
||||||
|
register_dataclass_as_pytree_node(MyOtherDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyOtherDataClass")
|
||||||
|
|
||||||
|
dt = MyOtherDataClass(x=3, y=4)
|
||||||
flat, spec = tree_flatten(dt)
|
flat, spec = tree_flatten(dt)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
spec,
|
spec,
|
||||||
TreeSpec(
|
TreeSpec(
|
||||||
MyDataClass,
|
MyOtherDataClass,
|
||||||
(
|
(
|
||||||
MyDataClass,
|
MyOtherDataClass,
|
||||||
['x', 'y', 'z'],
|
['x', 'y', 'z'],
|
||||||
[],
|
[],
|
||||||
),
|
),
|
||||||
@ -643,7 +650,7 @@ class TestExport(TestCase):
|
|||||||
self.assertEqual(flat, [3, 4, None])
|
self.assertEqual(flat, [3, 4, None])
|
||||||
|
|
||||||
orig_dt = tree_unflatten(flat, spec)
|
orig_dt = tree_unflatten(flat, spec)
|
||||||
self.assertTrue(isinstance(orig_dt, MyDataClass))
|
self.assertTrue(isinstance(orig_dt, MyOtherDataClass))
|
||||||
self.assertEqual(orig_dt.x, 3)
|
self.assertEqual(orig_dt.x, 3)
|
||||||
self.assertEqual(orig_dt.y, 4)
|
self.assertEqual(orig_dt.y, 4)
|
||||||
self.assertEqual(orig_dt.z, None)
|
self.assertEqual(orig_dt.z, None)
|
||||||
|
@ -3529,7 +3529,7 @@ class TestFX(JitTestCase):
|
|||||||
def f_namedtuple_add(x):
|
def f_namedtuple_add(x):
|
||||||
return x.x + x.y
|
return x.x + x.y
|
||||||
|
|
||||||
pytree._register_pytree_node(
|
pytree.register_pytree_node(
|
||||||
Foo,
|
Foo,
|
||||||
lambda x: ([x.a, x.b], None),
|
lambda x: ([x.a, x.b], None),
|
||||||
lambda x, _: Foo(x[0], x[1]),
|
lambda x, _: Foo(x[0], x[1]),
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Owner(s): ["module: pytree"]
|
# Owner(s): ["module: pytree"]
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from collections import namedtuple, OrderedDict
|
from collections import namedtuple, OrderedDict, UserDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils._cxx_pytree as cxx_pytree
|
import torch.utils._cxx_pytree as cxx_pytree
|
||||||
@ -26,6 +26,45 @@ class GlobalDummyType:
|
|||||||
|
|
||||||
|
|
||||||
class TestGenericPytree(TestCase):
|
class TestGenericPytree(TestCase):
|
||||||
|
@parametrize(
|
||||||
|
"pytree_impl",
|
||||||
|
[
|
||||||
|
subtest(py_pytree, name="py"),
|
||||||
|
subtest(cxx_pytree, name="cxx"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_register_pytree_node(self, pytree_impl):
|
||||||
|
class MyDict(UserDict):
|
||||||
|
pass
|
||||||
|
|
||||||
|
d = MyDict(a=1, b=2, c=3)
|
||||||
|
|
||||||
|
# Custom types are leaf nodes by default
|
||||||
|
values, spec = pytree_impl.tree_flatten(d)
|
||||||
|
self.assertEqual(values, [d])
|
||||||
|
self.assertIs(values[0], d)
|
||||||
|
self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
|
||||||
|
self.assertTrue(spec.is_leaf())
|
||||||
|
|
||||||
|
# Register MyDict as a pytree node
|
||||||
|
pytree_impl.register_pytree_node(
|
||||||
|
MyDict,
|
||||||
|
lambda d: (list(d.values()), list(d.keys())),
|
||||||
|
lambda values, keys: MyDict(zip(keys, values)),
|
||||||
|
)
|
||||||
|
|
||||||
|
values, spec = pytree_impl.tree_flatten(d)
|
||||||
|
self.assertEqual(values, [1, 2, 3])
|
||||||
|
self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
|
||||||
|
|
||||||
|
# Do not allow registering the same type twice
|
||||||
|
with self.assertRaisesRegex(ValueError, "already registered"):
|
||||||
|
pytree_impl.register_pytree_node(
|
||||||
|
MyDict,
|
||||||
|
lambda d: (list(d.values()), list(d.keys())),
|
||||||
|
lambda values, keys: MyDict(zip(keys, values)),
|
||||||
|
)
|
||||||
|
|
||||||
@parametrize(
|
@parametrize(
|
||||||
"pytree_impl",
|
"pytree_impl",
|
||||||
[
|
[
|
||||||
@ -407,6 +446,28 @@ class TestGenericPytree(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class TestPythonPytree(TestCase):
|
class TestPythonPytree(TestCase):
|
||||||
|
def test_deprecated_register_pytree_node(self):
|
||||||
|
class DummyType:
|
||||||
|
def __init__(self, x, y):
|
||||||
|
self.x = x
|
||||||
|
self.y = y
|
||||||
|
|
||||||
|
with self.assertWarnsRegex(
|
||||||
|
UserWarning, "torch.utils._pytree._register_pytree_node"
|
||||||
|
):
|
||||||
|
py_pytree._register_pytree_node(
|
||||||
|
DummyType,
|
||||||
|
lambda dummy: ([dummy.x, dummy.y], None),
|
||||||
|
lambda xs, _: DummyType(*xs),
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertWarnsRegex(UserWarning, "already registered"):
|
||||||
|
py_pytree._register_pytree_node(
|
||||||
|
DummyType,
|
||||||
|
lambda dummy: ([dummy.x, dummy.y], None),
|
||||||
|
lambda xs, _: DummyType(*xs),
|
||||||
|
)
|
||||||
|
|
||||||
def test_treespec_equality(self):
|
def test_treespec_equality(self):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
py_pytree.LeafSpec() == py_pytree.LeafSpec(),
|
py_pytree.LeafSpec() == py_pytree.LeafSpec(),
|
||||||
@ -540,7 +601,7 @@ TreeSpec(tuple, None, [*,
|
|||||||
self.x = x
|
self.x = x
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
||||||
py_pytree._register_pytree_node(
|
py_pytree.register_pytree_node(
|
||||||
DummyType,
|
DummyType,
|
||||||
lambda dummy: ([dummy.x, dummy.y], None),
|
lambda dummy: ([dummy.x, dummy.y], None),
|
||||||
lambda xs, _: DummyType(*xs),
|
lambda xs, _: DummyType(*xs),
|
||||||
@ -560,7 +621,7 @@ TreeSpec(tuple, None, [*,
|
|||||||
self.x = x
|
self.x = x
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
||||||
py_pytree._register_pytree_node(
|
py_pytree.register_pytree_node(
|
||||||
DummyType,
|
DummyType,
|
||||||
lambda dummy: ([dummy.x, dummy.y], None),
|
lambda dummy: ([dummy.x, dummy.y], None),
|
||||||
lambda xs, _: DummyType(*xs),
|
lambda xs, _: DummyType(*xs),
|
||||||
@ -585,7 +646,7 @@ TreeSpec(tuple, None, [*,
|
|||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "Both to_dumpable_context and from_dumpable_context"
|
ValueError, "Both to_dumpable_context and from_dumpable_context"
|
||||||
):
|
):
|
||||||
py_pytree._register_pytree_node(
|
py_pytree.register_pytree_node(
|
||||||
DummyType,
|
DummyType,
|
||||||
lambda dummy: ([dummy.x, dummy.y], None),
|
lambda dummy: ([dummy.x, dummy.y], None),
|
||||||
lambda xs, _: DummyType(*xs),
|
lambda xs, _: DummyType(*xs),
|
||||||
@ -599,7 +660,7 @@ TreeSpec(tuple, None, [*,
|
|||||||
self.x = x
|
self.x = x
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
||||||
py_pytree._register_pytree_node(
|
py_pytree.register_pytree_node(
|
||||||
DummyType,
|
DummyType,
|
||||||
lambda dummy: ([dummy.x, dummy.y], None),
|
lambda dummy: ([dummy.x, dummy.y], None),
|
||||||
lambda xs, _: DummyType(*xs),
|
lambda xs, _: DummyType(*xs),
|
||||||
|
@ -63,16 +63,16 @@ def register_dataclass_as_pytree_node(
|
|||||||
flatten_fn: Optional[FlattenFunc] = None,
|
flatten_fn: Optional[FlattenFunc] = None,
|
||||||
unflatten_fn: Optional[UnflattenFunc] = None,
|
unflatten_fn: Optional[UnflattenFunc] = None,
|
||||||
*,
|
*,
|
||||||
|
serialized_type_name: Optional[str] = None,
|
||||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||||
serialized_type_name: Optional[str] = None,
|
|
||||||
return_none_fields: bool = False,
|
return_none_fields: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert dataclasses.is_dataclass(
|
assert dataclasses.is_dataclass(
|
||||||
cls
|
cls
|
||||||
), f"Only dataclasses can be registered with this function: {cls}"
|
), f"Only dataclasses can be registered with this function: {cls}"
|
||||||
|
|
||||||
serialized_type = f"{cls.__module__}.{cls.__name__}"
|
serialized_type = f"{cls.__module__}.{cls.__qualname__}"
|
||||||
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
|
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
|
||||||
|
|
||||||
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
|
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
|
||||||
|
@ -29,7 +29,7 @@ from torch._logging import getArtifactLogger
|
|||||||
from torch._subclasses import FakeTensor, FakeTensorMode
|
from torch._subclasses import FakeTensor, FakeTensorMode
|
||||||
from torch._subclasses.fake_tensor import is_fake
|
from torch._subclasses.fake_tensor import is_fake
|
||||||
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
|
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
|
||||||
from torch.fx import immutable_collections, Interpreter
|
from torch.fx import Interpreter
|
||||||
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
|
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
|
||||||
from torch.fx.experimental.symbolic_shapes import (
|
from torch.fx.experimental.symbolic_shapes import (
|
||||||
ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq
|
ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq
|
||||||
@ -95,19 +95,6 @@ OutputType = Enum(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
pytree._register_pytree_node(
|
|
||||||
immutable_collections.immutable_list,
|
|
||||||
lambda x: (list(x), None),
|
|
||||||
lambda x, c: immutable_collections.immutable_list(x),
|
|
||||||
)
|
|
||||||
pytree._register_pytree_node(
|
|
||||||
immutable_collections.immutable_dict,
|
|
||||||
lambda x: (list(x.values()), list(x.keys())),
|
|
||||||
lambda x, c: immutable_collections.immutable_dict(
|
|
||||||
dict(zip(c, x))
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def partial_asdict(obj: Any) -> Any:
|
def partial_asdict(obj: Any) -> Any:
|
||||||
if dataclasses.is_dataclass(obj):
|
if dataclasses.is_dataclass(obj):
|
||||||
return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}
|
return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}
|
||||||
|
@ -49,7 +49,7 @@ CONSTANT_NUMEL_LIMIT = 1
|
|||||||
|
|
||||||
# We currently convert all SymInt to proxies before we use them.
|
# We currently convert all SymInt to proxies before we use them.
|
||||||
# This could plausibly be handled at the Dynamo level.
|
# This could plausibly be handled at the Dynamo level.
|
||||||
pytree._register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs))
|
pytree.register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs))
|
||||||
|
|
||||||
def fake_signature(fn, nargs):
|
def fake_signature(fn, nargs):
|
||||||
"""FX gets confused by varargs, de-confuse it"""
|
"""FX gets confused by varargs, de-confuse it"""
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import Any, Dict, Iterable, List, Tuple
|
from typing import Any, Dict, Iterable, List, Tuple
|
||||||
|
|
||||||
from ._compatibility import compatibility
|
from ._compatibility import compatibility
|
||||||
from torch.utils._pytree import Context, _register_pytree_node
|
from torch.utils._pytree import Context, register_pytree_node
|
||||||
|
|
||||||
__all__ = ["immutable_list", "immutable_dict"]
|
__all__ = ["immutable_list", "immutable_dict"]
|
||||||
|
|
||||||
@ -50,5 +50,5 @@ def _immutable_list_unflatten(values: Iterable[Any], context: Context) -> List[A
|
|||||||
return immutable_list(values)
|
return immutable_list(values)
|
||||||
|
|
||||||
|
|
||||||
_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten)
|
register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten)
|
||||||
_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)
|
register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)
|
||||||
|
@ -40,7 +40,11 @@ class _PyTreeExtensionContext:
|
|||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
for class_type, (flatten_func, unflatten_func) in self._extensions.items():
|
for class_type, (flatten_func, unflatten_func) in self._extensions.items():
|
||||||
pytree._register_pytree_node(class_type, flatten_func, unflatten_func)
|
pytree._private_register_pytree_node(
|
||||||
|
class_type,
|
||||||
|
flatten_func,
|
||||||
|
unflatten_func,
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
@ -93,8 +97,11 @@ class _PyTreeExtensionContext:
|
|||||||
# All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
|
# All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
|
||||||
named_model_output_classes = inspect.getmembers(
|
named_model_output_classes = inspect.getmembers(
|
||||||
modeling_outputs,
|
modeling_outputs,
|
||||||
lambda x: inspect.isclass(x)
|
lambda x: (
|
||||||
and issubclass(x, modeling_outputs.ModelOutput),
|
inspect.isclass(x)
|
||||||
|
and issubclass(x, modeling_outputs.ModelOutput)
|
||||||
|
and x is not modeling_outputs.ModelOutput
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, class_type in named_model_output_classes:
|
for _, class_type in named_model_output_classes:
|
||||||
|
@ -13,7 +13,7 @@ def pytree_register_structseq(cls):
|
|||||||
def structseq_unflatten(values, context):
|
def structseq_unflatten(values, context):
|
||||||
return cls(values)
|
return cls(values)
|
||||||
|
|
||||||
torch.utils._pytree._register_pytree_node(cls, structseq_flatten, structseq_unflatten)
|
torch.utils._pytree.register_pytree_node(cls, structseq_flatten, structseq_unflatten)
|
||||||
|
|
||||||
for name in dir(return_types):
|
for name in dir(return_types):
|
||||||
if name.startswith('__'):
|
if name.startswith('__'):
|
||||||
|
@ -13,6 +13,7 @@ collection support for PyTorch APIs.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import warnings
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
@ -26,6 +27,11 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if torch._running_with_deploy():
|
||||||
|
raise ImportError("C++ pytree utilities do not work with torch::deploy.")
|
||||||
|
|
||||||
import optree
|
import optree
|
||||||
from optree import PyTreeSpec # direct import for type annotations
|
from optree import PyTreeSpec # direct import for type annotations
|
||||||
|
|
||||||
@ -35,6 +41,9 @@ __all__ = [
|
|||||||
"Context",
|
"Context",
|
||||||
"FlattenFunc",
|
"FlattenFunc",
|
||||||
"UnflattenFunc",
|
"UnflattenFunc",
|
||||||
|
"DumpableContext",
|
||||||
|
"ToDumpableContextFn",
|
||||||
|
"FromDumpableContextFn",
|
||||||
"TreeSpec",
|
"TreeSpec",
|
||||||
"LeafSpec",
|
"LeafSpec",
|
||||||
"register_pytree_node",
|
"register_pytree_node",
|
||||||
@ -68,6 +77,9 @@ TreeSpec = PyTreeSpec
|
|||||||
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
|
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
|
||||||
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
|
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
|
||||||
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
|
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
|
||||||
|
DumpableContext = Any # Any json dumpable text
|
||||||
|
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||||
|
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
||||||
|
|
||||||
|
|
||||||
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
|
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
|
||||||
@ -84,9 +96,11 @@ def register_pytree_node(
|
|||||||
unflatten_fn: UnflattenFunc,
|
unflatten_fn: UnflattenFunc,
|
||||||
*,
|
*,
|
||||||
serialized_type_name: Optional[str] = None,
|
serialized_type_name: Optional[str] = None,
|
||||||
|
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||||
|
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||||
namespace: str = "torch",
|
namespace: str = "torch",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Extend the set of types that are considered internal nodes in pytrees.
|
"""Register a container-like type as pytree node.
|
||||||
|
|
||||||
The ``namespace`` argument is used to avoid collisions that occur when different libraries
|
The ``namespace`` argument is used to avoid collisions that occur when different libraries
|
||||||
register the same Python type with different behaviors. It is recommended to add a unique prefix
|
register the same Python type with different behaviors. It is recommended to add a unique prefix
|
||||||
@ -109,6 +123,13 @@ def register_pytree_node(
|
|||||||
The function should return an instance of ``cls``.
|
The function should return an instance of ``cls``.
|
||||||
serialized_type_name (str, optional): A keyword argument used to specify the fully
|
serialized_type_name (str, optional): A keyword argument used to specify the fully
|
||||||
qualified name used when serializing the tree spec.
|
qualified name used when serializing the tree spec.
|
||||||
|
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
|
||||||
|
to convert the context of the pytree to a custom json dumpable representation. This is
|
||||||
|
used for json serialization, which is being used in :mod:`torch.export` right now.
|
||||||
|
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
|
||||||
|
how to convert the custom json dumpable representation of the context back to the
|
||||||
|
original context. This is used for json deserialization, which is being used in
|
||||||
|
:mod:`torch.export` right now.
|
||||||
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
||||||
type registry. This is used to isolate the registry from other modules that might
|
type registry. This is used to isolate the registry from other modules that might
|
||||||
register a different custom behavior for the same type. (default: :const:`"torch"`)
|
register a different custom behavior for the same type. (default: :const:`"torch"`)
|
||||||
@ -193,24 +214,192 @@ def register_pytree_node(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
from ._pytree import _register_pytree_node
|
_private_register_pytree_node(
|
||||||
|
|
||||||
_register_pytree_node(
|
|
||||||
cls,
|
cls,
|
||||||
flatten_fn,
|
flatten_fn,
|
||||||
unflatten_fn,
|
unflatten_fn,
|
||||||
serialized_type_name=serialized_type_name,
|
serialized_type_name=serialized_type_name,
|
||||||
|
to_dumpable_context=to_dumpable_context,
|
||||||
|
from_dumpable_context=from_dumpable_context,
|
||||||
|
namespace=namespace,
|
||||||
)
|
)
|
||||||
|
|
||||||
optree.register_pytree_node(
|
from . import _pytree as python
|
||||||
|
|
||||||
|
python._private_register_pytree_node(
|
||||||
cls,
|
cls,
|
||||||
flatten_fn,
|
flatten_fn,
|
||||||
_reverse_args(unflatten_fn),
|
unflatten_fn,
|
||||||
|
serialized_type_name=serialized_type_name,
|
||||||
|
to_dumpable_context=to_dumpable_context,
|
||||||
|
from_dumpable_context=from_dumpable_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_pytree_node(
|
||||||
|
cls: Type[Any],
|
||||||
|
flatten_fn: FlattenFunc,
|
||||||
|
unflatten_fn: UnflattenFunc,
|
||||||
|
*,
|
||||||
|
serialized_type_name: Optional[str] = None,
|
||||||
|
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||||
|
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||||
|
namespace: str = "torch",
|
||||||
|
) -> None:
|
||||||
|
"""Register a container-like type as pytree node for the C++ pytree only.
|
||||||
|
|
||||||
|
The ``namespace`` argument is used to avoid collisions that occur when different libraries
|
||||||
|
register the same Python type with different behaviors. It is recommended to add a unique prefix
|
||||||
|
to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
|
||||||
|
the same class in different namespaces for different use cases.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
|
||||||
|
used to isolate the behavior of flattening and unflattening a pytree node type. This is to
|
||||||
|
prevent accidental collisions between different libraries that may register the same type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls (type): A Python type to treat as an internal pytree node.
|
||||||
|
flatten_fn (callable): A function to be used during flattening, taking an instance of
|
||||||
|
``cls`` and returning a pair, with (1) an iterable for the children to be flattened
|
||||||
|
recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
|
||||||
|
passed to the ``unflatten_fn``.
|
||||||
|
unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
|
||||||
|
returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
|
||||||
|
The function should return an instance of ``cls``.
|
||||||
|
serialized_type_name (str, optional): A keyword argument used to specify the fully
|
||||||
|
qualified name used when serializing the tree spec.
|
||||||
|
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
|
||||||
|
to convert the context of the pytree to a custom json dumpable representation. This is
|
||||||
|
used for json serialization, which is being used in :mod:`torch.export` right now.
|
||||||
|
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
|
||||||
|
how to convert the custom json dumpable representation of the context back to the
|
||||||
|
original context. This is used for json deserialization, which is being used in
|
||||||
|
:mod:`torch.export` right now.
|
||||||
|
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
||||||
|
type registry. This is used to isolate the registry from other modules that might
|
||||||
|
register a different custom behavior for the same type. (default: :const:`"torch"`)
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> # xdoctest: +SKIP
|
||||||
|
>>> # Registry a Python type with lambda functions
|
||||||
|
>>> register_pytree_node(
|
||||||
|
... set,
|
||||||
|
... lambda s: (sorted(s), None, None),
|
||||||
|
... lambda children, _: set(children),
|
||||||
|
... namespace='set',
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # xdoctest: +SKIP
|
||||||
|
>>> # Register a Python type into a namespace
|
||||||
|
>>> import torch
|
||||||
|
>>> register_pytree_node(
|
||||||
|
... torch.Tensor,
|
||||||
|
... flatten_func=lambda tensor: (
|
||||||
|
... (tensor.cpu().detach().numpy(),),
|
||||||
|
... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
|
||||||
|
... ),
|
||||||
|
... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata),
|
||||||
|
... namespace='torch2numpy',
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||||
|
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
|
||||||
|
>>> tree
|
||||||
|
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
|
||||||
|
|
||||||
|
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||||
|
>>> # Flatten without specifying the namespace
|
||||||
|
>>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP
|
||||||
|
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
|
||||||
|
|
||||||
|
>>> # xdoctest: +SKIP
|
||||||
|
>>> # Flatten with the namespace
|
||||||
|
>>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP
|
||||||
|
(
|
||||||
|
[array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
|
||||||
|
PyTreeSpec(
|
||||||
|
{
|
||||||
|
'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]),
|
||||||
|
'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*])
|
||||||
|
},
|
||||||
|
namespace='torch2numpy'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
>>> # xdoctest: +SKIP
|
||||||
|
>>> # Register the same type with a different namespace for different behaviors
|
||||||
|
>>> def tensor2flatparam(tensor):
|
||||||
|
... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
|
||||||
|
...
|
||||||
|
>>> def flatparam2tensor(children, metadata):
|
||||||
|
... return children[0].reshape(metadata)
|
||||||
|
...
|
||||||
|
>>> register_pytree_node(
|
||||||
|
... torch.Tensor,
|
||||||
|
... flatten_func=tensor2flatparam,
|
||||||
|
... unflatten_func=flatparam2tensor,
|
||||||
|
... namespace='tensor2flatparam',
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # xdoctest: +SKIP
|
||||||
|
>>> # Flatten with the new namespace
|
||||||
|
>>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP
|
||||||
|
(
|
||||||
|
[
|
||||||
|
Parameter containing: tensor([0., 0.], requires_grad=True),
|
||||||
|
Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
|
||||||
|
],
|
||||||
|
PyTreeSpec(
|
||||||
|
{
|
||||||
|
'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
|
||||||
|
'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
|
||||||
|
},
|
||||||
|
namespace='tensor2flatparam'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
warnings.warn(
|
||||||
|
"torch.utils._cxx_pytree._register_pytree_node is deprecated. "
|
||||||
|
"Please use torch.utils._cxx_pytree.register_pytree_node instead.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
_private_register_pytree_node(
|
||||||
|
cls,
|
||||||
|
flatten_fn,
|
||||||
|
unflatten_fn,
|
||||||
|
serialized_type_name=serialized_type_name,
|
||||||
|
to_dumpable_context=to_dumpable_context,
|
||||||
|
from_dumpable_context=from_dumpable_context,
|
||||||
namespace=namespace,
|
namespace=namespace,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_pytree_node = register_pytree_node
|
def _private_register_pytree_node(
|
||||||
|
cls: Type[Any],
|
||||||
|
flatten_fn: FlattenFunc,
|
||||||
|
unflatten_fn: UnflattenFunc,
|
||||||
|
*,
|
||||||
|
serialized_type_name: Optional[str] = None,
|
||||||
|
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||||
|
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||||
|
namespace: str = "torch",
|
||||||
|
) -> None:
|
||||||
|
"""This is an internal function that is used to register a pytree node type
|
||||||
|
for the C++ pytree only. End-users should use :func:`register_pytree_node`
|
||||||
|
instead.
|
||||||
|
"""
|
||||||
|
# TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
|
||||||
|
# PyStructSequence types
|
||||||
|
if not optree.is_structseq_class(cls):
|
||||||
|
optree.register_pytree_node(
|
||||||
|
cls,
|
||||||
|
flatten_fn,
|
||||||
|
_reverse_args(unflatten_fn),
|
||||||
|
namespace=namespace,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def tree_flatten(
|
def tree_flatten(
|
||||||
|
@ -17,6 +17,7 @@ To improve the performance we can move parts of the implementation to C++.
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
from collections import deque, namedtuple, OrderedDict
|
from collections import deque, namedtuple, OrderedDict
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -99,6 +100,7 @@ class NodeDef(NamedTuple):
|
|||||||
unflatten_fn: UnflattenFunc
|
unflatten_fn: UnflattenFunc
|
||||||
|
|
||||||
|
|
||||||
|
_NODE_REGISTRY_LOCK = threading.Lock()
|
||||||
SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}
|
SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}
|
||||||
|
|
||||||
|
|
||||||
@ -120,18 +122,17 @@ SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {}
|
|||||||
SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {}
|
SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
def _register_pytree_node(
|
def register_pytree_node(
|
||||||
cls: Any,
|
cls: Any,
|
||||||
flatten_fn: FlattenFunc,
|
flatten_fn: FlattenFunc,
|
||||||
unflatten_fn: UnflattenFunc,
|
unflatten_fn: UnflattenFunc,
|
||||||
to_str_fn: Optional[ToStrFunc] = None, # deprecated
|
|
||||||
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated
|
|
||||||
*,
|
*,
|
||||||
serialized_type_name: Optional[str] = None,
|
serialized_type_name: Optional[str] = None,
|
||||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""Register a container-like type as pytree node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cls: the type to register
|
cls: the type to register
|
||||||
flatten_fn: A callable that takes a pytree and returns a flattened
|
flatten_fn: A callable that takes a pytree and returns a flattened
|
||||||
@ -150,39 +151,132 @@ def _register_pytree_node(
|
|||||||
back to the original context. This is used for json deserialization,
|
back to the original context. This is used for json deserialization,
|
||||||
which is being used in torch.export right now.
|
which is being used in torch.export right now.
|
||||||
"""
|
"""
|
||||||
|
with _NODE_REGISTRY_LOCK:
|
||||||
|
if cls in SUPPORTED_NODES:
|
||||||
|
raise ValueError(f"{cls} is already registered as pytree node.")
|
||||||
|
|
||||||
|
_private_register_pytree_node(
|
||||||
|
cls,
|
||||||
|
flatten_fn,
|
||||||
|
unflatten_fn,
|
||||||
|
serialized_type_name=serialized_type_name,
|
||||||
|
to_dumpable_context=to_dumpable_context,
|
||||||
|
from_dumpable_context=from_dumpable_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from . import _cxx_pytree as cxx
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
cxx._private_register_pytree_node(
|
||||||
|
cls,
|
||||||
|
flatten_fn,
|
||||||
|
unflatten_fn,
|
||||||
|
serialized_type_name=serialized_type_name,
|
||||||
|
to_dumpable_context=to_dumpable_context,
|
||||||
|
from_dumpable_context=from_dumpable_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_pytree_node(
|
||||||
|
cls: Any,
|
||||||
|
flatten_fn: FlattenFunc,
|
||||||
|
unflatten_fn: UnflattenFunc,
|
||||||
|
to_str_fn: Optional[ToStrFunc] = None, # deprecated
|
||||||
|
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated
|
||||||
|
*,
|
||||||
|
serialized_type_name: Optional[str] = None,
|
||||||
|
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||||
|
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Register a container-like type as pytree node for the Python pytree only.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls: the type to register
|
||||||
|
flatten_fn: A callable that takes a pytree and returns a flattened
|
||||||
|
representation of the pytree and additional context to represent the
|
||||||
|
flattened pytree.
|
||||||
|
unflatten_fn: A callable that takes a flattened version of the pytree,
|
||||||
|
additional context, and returns an unflattened pytree.
|
||||||
|
serialized_type_name: A keyword argument used to specify the fully qualified
|
||||||
|
name used when serializing the tree spec.
|
||||||
|
to_dumpable_context: An optional keyword argument to custom specify how
|
||||||
|
to convert the context of the pytree to a custom json dumpable
|
||||||
|
representation. This is used for json serialization, which is being
|
||||||
|
used in torch.export right now.
|
||||||
|
from_dumpable_context: An optional keyword argument to custom specify how
|
||||||
|
to convert the custom json dumpable representation of the context
|
||||||
|
back to the original context. This is used for json deserialization,
|
||||||
|
which is being used in torch.export right now.
|
||||||
|
"""
|
||||||
|
warnings.warn(
|
||||||
|
"torch.utils._pytree._register_pytree_node is deprecated. "
|
||||||
|
"Please use torch.utils._pytree.register_pytree_node instead.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
if to_str_fn is not None or maybe_from_str_fn is not None:
|
if to_str_fn is not None or maybe_from_str_fn is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"to_str_fn and maybe_from_str_fn is deprecated. "
|
"to_str_fn and maybe_from_str_fn is deprecated. "
|
||||||
"Please use to_dumpable_context and from_dumpable_context instead."
|
"Please use to_dumpable_context and from_dumpable_context instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
node_def = NodeDef(
|
_private_register_pytree_node(
|
||||||
cls,
|
cls,
|
||||||
flatten_fn,
|
flatten_fn,
|
||||||
unflatten_fn,
|
unflatten_fn,
|
||||||
|
serialized_type_name=serialized_type_name,
|
||||||
|
to_dumpable_context=to_dumpable_context,
|
||||||
|
from_dumpable_context=from_dumpable_context,
|
||||||
)
|
)
|
||||||
SUPPORTED_NODES[cls] = node_def
|
|
||||||
|
|
||||||
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
|
|
||||||
raise ValueError(
|
def _private_register_pytree_node(
|
||||||
f"Both to_dumpable_context and from_dumpable_context for {cls} must "
|
cls: Any,
|
||||||
"be None or registered."
|
flatten_fn: FlattenFunc,
|
||||||
|
unflatten_fn: UnflattenFunc,
|
||||||
|
*,
|
||||||
|
serialized_type_name: Optional[str] = None,
|
||||||
|
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||||
|
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||||
|
) -> None:
|
||||||
|
"""This is an internal function that is used to register a pytree node type
|
||||||
|
for the Python pytree only. End-users should use :func:`register_pytree_node`
|
||||||
|
instead.
|
||||||
|
"""
|
||||||
|
with _NODE_REGISTRY_LOCK:
|
||||||
|
if cls in SUPPORTED_NODES:
|
||||||
|
# TODO: change this warning to an error after OSS/internal stabilize
|
||||||
|
warnings.warn(
|
||||||
|
f"{cls} is already registered as pytree node. "
|
||||||
|
"Overwriting the previous registration.",
|
||||||
|
)
|
||||||
|
|
||||||
|
node_def = NodeDef(
|
||||||
|
cls,
|
||||||
|
flatten_fn,
|
||||||
|
unflatten_fn,
|
||||||
)
|
)
|
||||||
|
SUPPORTED_NODES[cls] = node_def
|
||||||
|
|
||||||
if serialized_type_name is None:
|
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
|
||||||
serialized_type_name = f"{cls.__module__}.{cls.__name__}"
|
raise ValueError(
|
||||||
|
f"Both to_dumpable_context and from_dumpable_context for {cls} must "
|
||||||
|
"be None or registered."
|
||||||
|
)
|
||||||
|
|
||||||
serialize_node_def = _SerializeNodeDef(
|
if serialized_type_name is None:
|
||||||
cls,
|
serialized_type_name = f"{cls.__module__}.{cls.__qualname__}"
|
||||||
serialized_type_name,
|
|
||||||
to_dumpable_context,
|
|
||||||
from_dumpable_context,
|
|
||||||
)
|
|
||||||
SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
|
|
||||||
SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
|
|
||||||
|
|
||||||
|
serialize_node_def = _SerializeNodeDef(
|
||||||
register_pytree_node = _register_pytree_node
|
cls,
|
||||||
|
serialized_type_name,
|
||||||
|
to_dumpable_context,
|
||||||
|
from_dumpable_context,
|
||||||
|
)
|
||||||
|
SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
|
||||||
|
SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
|
||||||
|
|
||||||
|
|
||||||
def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
|
def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
|
||||||
@ -243,25 +337,25 @@ def _odict_unflatten(
|
|||||||
return OrderedDict((key, value) for key, value in zip(context, values))
|
return OrderedDict((key, value) for key, value in zip(context, values))
|
||||||
|
|
||||||
|
|
||||||
_register_pytree_node(
|
_private_register_pytree_node(
|
||||||
dict,
|
dict,
|
||||||
_dict_flatten,
|
_dict_flatten,
|
||||||
_dict_unflatten,
|
_dict_unflatten,
|
||||||
serialized_type_name="builtins.dict",
|
serialized_type_name="builtins.dict",
|
||||||
)
|
)
|
||||||
_register_pytree_node(
|
_private_register_pytree_node(
|
||||||
list,
|
list,
|
||||||
_list_flatten,
|
_list_flatten,
|
||||||
_list_unflatten,
|
_list_unflatten,
|
||||||
serialized_type_name="builtins.list",
|
serialized_type_name="builtins.list",
|
||||||
)
|
)
|
||||||
_register_pytree_node(
|
_private_register_pytree_node(
|
||||||
tuple,
|
tuple,
|
||||||
_tuple_flatten,
|
_tuple_flatten,
|
||||||
_tuple_unflatten,
|
_tuple_unflatten,
|
||||||
serialized_type_name="builtins.tuple",
|
serialized_type_name="builtins.tuple",
|
||||||
)
|
)
|
||||||
_register_pytree_node(
|
_private_register_pytree_node(
|
||||||
namedtuple,
|
namedtuple,
|
||||||
_namedtuple_flatten,
|
_namedtuple_flatten,
|
||||||
_namedtuple_unflatten,
|
_namedtuple_unflatten,
|
||||||
@ -269,7 +363,7 @@ _register_pytree_node(
|
|||||||
from_dumpable_context=_namedtuple_deserialize,
|
from_dumpable_context=_namedtuple_deserialize,
|
||||||
serialized_type_name="collections.namedtuple",
|
serialized_type_name="collections.namedtuple",
|
||||||
)
|
)
|
||||||
_register_pytree_node(
|
_private_register_pytree_node(
|
||||||
OrderedDict,
|
OrderedDict,
|
||||||
_odict_flatten,
|
_odict_flatten,
|
||||||
_odict_unflatten,
|
_odict_unflatten,
|
||||||
@ -729,7 +823,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
|
|||||||
|
|
||||||
if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
|
if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Serializing {treespec.type} in pytree is not registered."
|
f"Serializing {treespec.type} in pytree is not registered.",
|
||||||
)
|
)
|
||||||
|
|
||||||
serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]
|
serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]
|
||||||
|
Reference in New Issue
Block a user