Compare commits

...

55 Commits

Author SHA1 Message Date
9bf357485d Update
[ghstack-poisoned]
2025-11-14 00:01:02 +08:00
82cee53b5b Update (base update)
[ghstack-poisoned]
2025-11-14 00:01:02 +08:00
fcf64c5430 Update
[ghstack-poisoned]
2025-10-29 20:59:32 +08:00
e919487cee Update (base update)
[ghstack-poisoned]
2025-10-29 20:59:32 +08:00
663a6641b8 Update
[ghstack-poisoned]
2025-10-11 21:38:33 +08:00
c81fff72c4 Update (base update)
[ghstack-poisoned]
2025-10-11 21:38:33 +08:00
0b0d14e369 Update
[ghstack-poisoned]
2025-10-08 22:37:37 +08:00
8e0faf10d7 Update (base update)
[ghstack-poisoned]
2025-10-08 22:37:37 +08:00
55952b9410 Update
[ghstack-poisoned]
2025-09-19 18:12:33 +08:00
3333f9a7fe Update (base update)
[ghstack-poisoned]
2025-09-19 18:12:33 +08:00
35dc5e9053 Update
[ghstack-poisoned]
2025-09-06 11:35:55 +08:00
a37b672868 Update (base update)
[ghstack-poisoned]
2025-09-06 11:35:55 +08:00
5ce62d8ca4 Update
[ghstack-poisoned]
2025-08-17 16:24:30 +08:00
8a44116231 Update (base update)
[ghstack-poisoned]
2025-08-17 16:24:29 +08:00
27b3790f84 Update
[ghstack-poisoned]
2025-08-09 02:57:36 +08:00
4c87eb03ad Update (base update)
[ghstack-poisoned]
2025-08-09 02:57:36 +08:00
6436b79440 Update
[ghstack-poisoned]
2025-07-25 20:01:36 +08:00
cd3da610db Update (base update)
[ghstack-poisoned]
2025-07-25 20:01:36 +08:00
e4b79a3417 Update
[ghstack-poisoned]
2025-07-09 18:57:05 +08:00
fd46ad836e Update (base update)
[ghstack-poisoned]
2025-07-09 18:57:05 +08:00
f318471eae Update
[ghstack-poisoned]
2025-07-03 21:41:56 +08:00
aeef130c4a Update (base update)
[ghstack-poisoned]
2025-07-03 21:41:56 +08:00
5d3e10c781 Update
[ghstack-poisoned]
2025-06-18 23:16:07 +08:00
110120a627 Update (base update)
[ghstack-poisoned]
2025-06-18 23:16:07 +08:00
0564cb9ed9 Update
[ghstack-poisoned]
2025-06-07 18:03:47 +08:00
2a60e06b41 Update (base update)
[ghstack-poisoned]
2025-06-07 18:03:47 +08:00
bc6876ef87 Update
[ghstack-poisoned]
2025-05-31 21:52:37 +08:00
f5d96573a9 Update (base update)
[ghstack-poisoned]
2025-05-31 21:52:37 +08:00
fd5943385b Update
[ghstack-poisoned]
2025-05-28 20:46:34 +08:00
5fd24037de Update (base update)
[ghstack-poisoned]
2025-05-28 20:46:34 +08:00
26f96f1783 Update
[ghstack-poisoned]
2025-05-16 11:40:46 +08:00
ec8ad8f509 Update (base update)
[ghstack-poisoned]
2025-05-16 11:40:46 +08:00
8849c86bad Update
[ghstack-poisoned]
2025-05-14 20:38:36 +08:00
9eb1a4282c Update (base update)
[ghstack-poisoned]
2025-05-14 20:38:36 +08:00
446d76b993 Update
[ghstack-poisoned]
2025-05-08 21:24:23 +08:00
9e23b9fc4f Update (base update)
[ghstack-poisoned]
2025-05-08 21:24:23 +08:00
7768d54131 Update
[ghstack-poisoned]
2025-05-02 23:45:19 +08:00
24a0401d9f Update (base update)
[ghstack-poisoned]
2025-05-02 20:28:06 +08:00
8421a29a43 Update
[ghstack-poisoned]
2025-05-02 20:28:06 +08:00
fc31b07c88 Update (base update)
[ghstack-poisoned]
2025-04-09 22:47:00 +08:00
6af2396424 Update
[ghstack-poisoned]
2025-04-09 22:47:00 +08:00
248dcc0d0f Update (base update)
[ghstack-poisoned]
2025-04-01 23:32:40 +08:00
1adfe613e5 Update
[ghstack-poisoned]
2025-04-01 23:32:40 +08:00
a3e6c756f4 Update (base update)
[ghstack-poisoned]
2025-03-31 21:22:11 +08:00
2105991fa0 Update
[ghstack-poisoned]
2025-03-31 21:22:11 +08:00
cce45b1061 Update (base update)
[ghstack-poisoned]
2025-03-14 12:48:20 +08:00
fd94b3884b Update
[ghstack-poisoned]
2025-03-14 12:48:20 +08:00
c382241afd Update
[ghstack-poisoned]
2025-03-14 03:57:51 +08:00
062bbcd4b0 Update (base update)
[ghstack-poisoned]
2025-03-14 02:39:35 +08:00
7286099b05 Update
[ghstack-poisoned]
2025-03-14 02:39:35 +08:00
6847701526 Update (base update)
[ghstack-poisoned]
2025-03-13 04:54:01 +08:00
d3b3d16438 Update
[ghstack-poisoned]
2025-03-13 04:54:01 +08:00
5d89057ba9 Update
[ghstack-poisoned]
2025-03-05 04:58:51 +08:00
8634a7df00 Update (base update)
[ghstack-poisoned]
2025-03-05 04:53:51 +08:00
49ef07a8d1 Update
[ghstack-poisoned]
2025-03-05 04:53:51 +08:00
10 changed files with 241 additions and 114 deletions

View File

@ -6230,9 +6230,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
register_pytree_node(
Box,
lambda box: ([box.content], None), # flatten_fn
lambda contents, _context: Box(*contents), # unflatten_fn
flatten_with_keys_fn=None, # unflatten_fn
lambda box: ([box.content], None), # flatten_func
lambda contents, _: Box(*contents), # unflatten_func
flatten_with_keys_func=None, # flatten_with_keys_func
serialized_type_name="test_no_suggested_fixes_for_data_dependent_errors.Box",
)

View File

@ -1347,9 +1347,9 @@ if "optree" in sys.modules:
python_pytree.register_pytree_node(
ACustomPytree,
flatten_fn=lambda f: ([f.x, f.y], f.z),
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
flatten_func=lambda f: ([f.x, f.y], f.z),
unflatten_func=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
flatten_with_keys_func=lambda f: ((("x", f.x), ("y", f.y)), f.z),
)
from_two_trees = python_pytree.tree_map_with_path(
lambda kp, a, b: a + b, tree1, tree2
@ -1379,9 +1379,9 @@ if "optree" in sys.modules:
python_pytree.register_pytree_node(
ACustomPytree,
flatten_fn=lambda f: ([f.x, f.y], f.z),
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
flatten_func=lambda f: ([f.x, f.y], f.z),
unflatten_func=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
flatten_with_keys_func=lambda f: ((("x", f.x), ("y", f.y)), f.z),
)
SOME_PYTREES = [
@ -1411,9 +1411,9 @@ if "optree" in sys.modules:
python_pytree.register_pytree_node(
ACustomPytree,
flatten_fn=lambda f: ([f.x, f.y], f.z),
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
flatten_func=lambda f: ([f.x, f.y], f.z),
unflatten_func=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
flatten_with_keys_func=lambda f: ((("x", f.x), ("y", f.y)), f.z),
)
SOME_PYTREES = [

View File

@ -36,7 +36,7 @@ from torch.fx._pytree import (
)
from torch.utils._pytree import (
_deregister_pytree_node,
_register_pytree_node,
_private_register_pytree_node,
Context,
FlattenFunc,
FromDumpableContextFn,
@ -481,8 +481,8 @@ def _check_input_constraints_for_graph(
def register_dataclass_as_pytree_node(
cls: type[Any],
flatten_fn: Optional[FlattenFunc] = None,
unflatten_fn: Optional[UnflattenFunc] = None,
flatten_func: Optional[FlattenFunc] = None,
unflatten_func: Optional[UnflattenFunc] = None,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
@ -494,7 +494,7 @@ def register_dataclass_as_pytree_node(
)
@torch._dynamo.dont_skip_tracing
def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
def default_flatten_func(obj: Any) -> tuple[list[Any], Context]:
flattened = []
flat_names = []
none_names = []
@ -508,17 +508,19 @@ def register_dataclass_as_pytree_node(
return flattened, [flat_names, none_names]
@torch._dynamo.dont_skip_tracing
def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
def default_unflatten_func(values: Iterable[Any], context: Context) -> Any:
flat_names, none_names = context
return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
@torch._dynamo.dont_skip_tracing
def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
flattened, (flat_names, _none_names) = flatten_fn(obj) # type: ignore[misc]
def default_flatten_func_with_keys(obj: Any) -> tuple[list[Any], Context]:
flattened, (flat_names, _none_names) = flatten_func(obj) # type: ignore[misc]
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names
flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn
if flatten_func is None:
flatten_func = default_flatten_func
if unflatten_func is None:
unflatten_func = default_unflatten_func
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
raise ValueError(
@ -526,12 +528,12 @@ def register_dataclass_as_pytree_node(
"be None or registered."
)
_register_pytree_node(
_private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
flatten_func,
unflatten_func,
serialized_type_name=serialized_type_name,
flatten_with_keys_fn=default_flatten_fn_with_keys,
flatten_with_keys_func=default_flatten_func_with_keys,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
)
@ -1480,7 +1482,7 @@ def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
def __deepcopy__(self, memo):
return PrototypeModule(self())
def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
def default_flatten_func(obj: Any) -> tuple[list[Any], Context]:
named_parameters = dict(obj.named_parameters())
named_buffers = dict(obj.named_buffers())
params_buffers = {**named_parameters, **named_buffers}
@ -1489,13 +1491,13 @@ def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
PrototypeModule(obj),
]
def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
def default_unflatten_func(values: Iterable[Any], context: Context) -> Any:
flat_names, ref = context
if ref is None or ref() is None:
raise RuntimeError("Module has been garbage collected")
obj = ref()
assert flatten_fn is not None
flattened, _ = flatten_fn(obj)
assert flatten_func is not None
flattened, _ = flatten_func(obj)
# NOTE: This helper function will replicate an nn.Module in the exactly same
# structure to be used together with _reparameterize_module. This will
@ -1517,15 +1519,15 @@ def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
ret = obj
return ret
def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
flattened, [flat_names, *args] = flatten_fn(obj) # type: ignore[misc]
def default_flatten_func_with_keys(obj: Any) -> tuple[list[Any], Context]:
flattened, [flat_names, *args] = flatten_func(obj) # type: ignore[misc]
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], [
flat_names,
*args,
]
flatten_fn = default_flatten_fn
unflatten_fn = default_unflatten_fn
flatten_func = default_flatten_func
unflatten_func = default_unflatten_func
serialized_type_name = cls.__module__ + "." + cls.__qualname__
@ -1538,18 +1540,18 @@ def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
s[1] = PrototypeModule(torch.nn.Module())
return s
_register_pytree_node(
_private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
flatten_func,
unflatten_func,
serialized_type_name=serialized_type_name,
flatten_with_keys_fn=default_flatten_fn_with_keys,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
flatten_with_keys_func=default_flatten_func_with_keys,
)
def default_flatten_fn_spec(obj, spec) -> list[Any]:
flats, context = flatten_fn(obj)
flats, context = flatten_func(obj)
assert context == spec.context
return flats

View File

@ -623,7 +623,7 @@ def _tree_map_with_path(
# in which case flatten and recurse
return tree_map_with_path(
f,
SUPPORTED_NODES[typ].flatten_fn(t)[0],
SUPPORTED_NODES[typ].flatten_func(t)[0],
*dynamic_shapes,
is_leaf=is_leaf,
)

View File

@ -20,11 +20,11 @@ _V = TypeVar("_V")
def register_pytree_flatten_spec(
cls: type[Any],
flatten_fn_spec: FlattenFuncSpec,
flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
flatten_func_spec: FlattenFuncSpec,
flatten_func_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
) -> None:
SUPPORTED_NODES[cls] = flatten_fn_spec
SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec
SUPPORTED_NODES[cls] = flatten_func_spec
SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_func_exact_match_spec
def _deregister_pytree_flatten_spec(
@ -46,8 +46,8 @@ def tree_flatten_spec(
# as export serializes the pytree separately.
# Will remove it in follow up PR.
if spec.type in SUPPORTED_NODES:
flatten_fn_spec = SUPPORTED_NODES[spec.type]
child_pytrees = flatten_fn_spec(pytree, spec)
flatten_func_spec = SUPPORTED_NODES[spec.type]
child_pytrees = flatten_func_spec(pytree, spec)
result = []
for child, child_spec in zip(child_pytrees, spec.children()):
flat = tree_flatten_spec(child, child_spec)

View File

@ -129,7 +129,7 @@ pytree.register_pytree_node(
lambda xs: (list(xs), None),
lambda xs, _: tuple(xs),
# pyrefly: ignore [bad-argument-type]
flatten_with_keys_fn=lambda xs: (
flatten_with_keys_func=lambda xs: (
[(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
None,
),

View File

@ -111,12 +111,12 @@ register_pytree_node(
_immutable_list_flatten,
_immutable_list_unflatten,
serialized_type_name="torch.fx.immutable_collections.immutable_list",
flatten_with_keys_fn=_list_flatten_with_keys,
flatten_with_keys_func=_list_flatten_with_keys,
)
register_pytree_node(
immutable_dict,
_immutable_dict_flatten,
_immutable_dict_unflatten,
serialized_type_name="torch.fx.immutable_collections.immutable_dict",
flatten_with_keys_fn=_dict_flatten_with_keys,
flatten_with_keys_func=_dict_flatten_with_keys,
)

View File

@ -27,7 +27,7 @@ def pytree_register_structseq(cls):
cls,
structseq_flatten,
structseq_unflatten,
flatten_with_keys_fn=structseq_flatten_with_keys,
flatten_with_keys_func=structseq_flatten_with_keys,
)

View File

@ -14,6 +14,7 @@ collection support for PyTorch APIs.
import functools
import types
import warnings
from collections.abc import Callable, Iterable, Mapping
from typing import Any, overload, TypeAlias, TypeVar, Union
from typing_extensions import deprecated, Self, TypeIs
@ -125,24 +126,28 @@ def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
def register_pytree_node(
cls: type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
flatten_func: FlattenFunc = None, # type: ignore[assignment] # the type is guaranteed
unflatten_func: UnflattenFunc = None, # type: ignore[assignment] # the type is guaranteed
*,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
flatten_with_keys_func: FlattenWithKeysFunc | None = None,
# TODO(XuehaiPan): remove these deprecated arguments and remove the type ignore above
flatten_fn: FlattenFunc | None = None,
unflatten_fn: UnflattenFunc | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""Register a container-like type as pytree node.
Args:
cls (type): A Python type to treat as an internal pytree node.
flatten_fn (callable): A function to be used during flattening, taking an instance of
flatten_func (callable): A function to be used during flattening, taking an instance of
``cls`` and returning a pair, with (1) an iterable for the children to be flattened
recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
passed to the ``unflatten_fn``.
unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
passed to the ``unflatten_func``.
unflatten_func (callable): A function taking two arguments: the auxiliary data that was
returned by ``flatten_func`` and stored in the treespec, and the unflattened children.
The function should return an instance of ``cls``.
serialized_type_name (str, optional): A keyword argument used to specify the fully
qualified name used when serializing the tree spec.
@ -164,13 +169,57 @@ def register_pytree_node(
... lambda children, _: set(children),
... )
"""
if flatten_with_keys_fn is not None:
if flatten_with_keys_func is not None:
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
if (flatten_func is None) != (unflatten_func is None):
raise ValueError(
"Both flatten_func and unflatten_func must be provided together."
)
if (flatten_fn is None) != (unflatten_fn is None):
raise ValueError("Both flatten_fn and unflatten_fn must be provided together.")
if flatten_func is None and flatten_fn is None:
raise TypeError(
"Missing required argument: 'flatten_func' and 'unflatten_func'."
)
if flatten_func is not None and flatten_fn is not None:
raise ValueError(
"Either (flatten_func, unflatten_func) or (flatten_fn, unflatten_fn) "
"should be provided, not both."
)
if flatten_with_keys_func is not None and flatten_with_keys_fn is not None:
raise ValueError(
"Either flatten_with_keys_func or flatten_with_keys_fn "
"should be provided, not both."
)
if flatten_fn is not None:
warnings.warn(
"The `flatten_fn` and `unflatten_fn` arguments are deprecated. "
"Use `flatten_func` and `unflatten_func` instead.",
category=FutureWarning,
statcklevel=2,
)
(
(flatten_func, unflatten_func),
(flatten_fn, unflatten_fn),
) = (
(flatten_fn, unflatten_fn),
(None, None),
)
if flatten_with_keys_fn is not None:
warnings.warn(
"The `flatten_with_keys_fn` argument is deprecated. "
"Use `flatten_with_keys_func` instead.",
category=FutureWarning,
stacklevel=2,
)
flatten_with_keys_func, flatten_with_keys_fn = flatten_with_keys_fn, None
_private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
flatten_func,
unflatten_func,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
@ -178,8 +227,8 @@ def register_pytree_node(
python_pytree._private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
flatten_func,
unflatten_func,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
@ -244,8 +293,8 @@ def _register_pytree_node(
def _private_register_pytree_node(
cls: type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
flatten_func: FlattenFunc,
unflatten_func: UnflattenFunc,
*,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
@ -260,8 +309,8 @@ def _private_register_pytree_node(
if not optree.is_structseq_class(cls):
optree.register_pytree_node(
cls,
flatten_fn,
_reverse_args(unflatten_fn),
flatten_func,
_reverse_args(unflatten_func),
namespace="torch",
)

View File

@ -132,19 +132,46 @@ FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
# A NodeDef holds two callables:
# - flatten_fn should take the collection and return a flat list of values.
# - flatten_func should take the collection and return a flat list of values.
# It can also return some context that is used in reconstructing the
# collection.
# - unflatten_fn should take a flat list of values and some context
# (returned by flatten_fn). It returns the collection by reconstructing
# - unflatten_func should take a flat list of values and some context
# (returned by flatten_func). It returns the collection by reconstructing
# it from the list and the context.
# - flatten_with_keys_fn, which is a callable that takes a
# - flatten_with_keys_func, which is a callable that takes a
# pytree and returns a list of (keypath, value) pairs and a context.
class NodeDef(NamedTuple):
type: type[Any]
flatten_fn: FlattenFunc
unflatten_fn: UnflattenFunc
flatten_with_keys_fn: FlattenWithKeysFunc | None
flatten_func: FlattenFunc
unflatten_func: UnflattenFunc
flatten_with_keys_func: FlattenWithKeysFunc | None = None
@property
@deprecated(
"`NodeDef.flatten_fn` is deprecated. "
"Please use `NodeDef.flatten_func` instead.",
category=FutureWarning,
)
def flatten_fn(self) -> FlattenFunc:
return self.flatten_func
@property
@deprecated(
"`NodeDef.unflatten_fn` is deprecated. "
"Please use `NodeDef.unflatten_func` instead.",
category=FutureWarning,
)
def unflatten_fn(self) -> UnflattenFunc:
return self.unflatten_func
@property
@deprecated(
"`NodeDef.flatten_with_keys_fn` is deprecated. "
"Please use `NodeDef.flatten_with_keys_func` instead.",
category=FutureWarning,
)
def flatten_with_keys_fn(self) -> FlattenWithKeysFunc | None:
return self.flatten_with_keys_func
_NODE_REGISTRY_LOCK = threading.RLock()
@ -195,12 +222,16 @@ _cxx_pytree_pending_imports: list[Any] = []
def register_pytree_node(
cls: type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
flatten_func: FlattenFunc = None, # type: ignore[assignment] # the type is guaranteed
unflatten_func: UnflattenFunc = None, # type: ignore[assignment] # the type is guaranteed
*,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
flatten_with_keys_func: FlattenWithKeysFunc | None = None,
# TODO(XuehaiPan): remove these deprecated arguments and remove the type ignore above
flatten_fn: FlattenFunc | None = None,
unflatten_fn: UnflattenFunc | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""Register a container-like type as pytree node.
@ -211,10 +242,10 @@ def register_pytree_node(
Args:
cls: the type to register
flatten_fn: A callable that takes a pytree and returns a flattened
flatten_func: A callable that takes a pytree and returns a flattened
representation of the pytree and additional context to represent the
flattened pytree.
unflatten_fn: A callable that takes a flattened version of the pytree,
unflatten_func: A callable that takes a flattened version of the pytree,
additional context, and returns an unflattened pytree.
serialized_type_name: A keyword argument used to specify the fully qualified
name used when serializing the tree spec.
@ -226,23 +257,68 @@ def register_pytree_node(
to convert the custom json dumpable representation of the context
back to the original context. This is used for json deserialization,
which is being used in torch.export right now.
flatten_with_keys_fn: An optional keyword argument to specify how to
flatten_with_keys_func: An optional keyword argument to specify how to
access each pytree leaf's keypath when flattening and tree-mapping.
Like ``flatten_fn``, but in place of a List[leaf], it should return
Like ``flatten_func``, but in place of a List[leaf], it should return
a List[(keypath, leaf)].
"""
with _NODE_REGISTRY_LOCK:
if cls in SUPPORTED_NODES:
raise ValueError(f"{cls} is already registered as pytree node.")
if (flatten_func is None) != (unflatten_func is None):
raise ValueError(
"Both flatten_func and unflatten_func must be provided together."
)
if (flatten_fn is None) != (unflatten_fn is None):
raise ValueError("Both flatten_fn and unflatten_fn must be provided together.")
if flatten_func is None and flatten_fn is None:
raise TypeError(
"Missing required argument: 'flatten_func' and 'unflatten_func'."
)
if flatten_func is not None and flatten_fn is not None:
raise ValueError(
"Either (flatten_func, unflatten_func) or (flatten_fn, unflatten_fn) "
"should be provided, not both."
)
if flatten_with_keys_func is not None and flatten_with_keys_fn is not None:
raise ValueError(
"Either flatten_with_keys_func or flatten_with_keys_fn "
"should be provided, not both."
)
if flatten_fn is not None:
warnings.warn(
"The `flatten_fn` and `unflatten_fn` arguments are deprecated. "
"Use `flatten_func` and `unflatten_func` instead. "
"Please consider passing `flatten_func` and `unflatten_func` via positional arguments",
category=FutureWarning,
statcklevel=2,
)
(
(flatten_func, unflatten_func),
(flatten_fn, unflatten_fn),
) = (
(flatten_fn, unflatten_fn),
(None, None),
)
if flatten_with_keys_fn is not None:
warnings.warn(
"The `flatten_with_keys_fn` argument is deprecated. "
"Use `flatten_with_keys_func` instead.",
category=FutureWarning,
stacklevel=2,
)
flatten_with_keys_func, flatten_with_keys_fn = flatten_with_keys_fn, None
_private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
flatten_func,
unflatten_func,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
flatten_with_keys_fn=flatten_with_keys_fn,
flatten_with_keys_func=flatten_with_keys_func,
)
if not _cxx_pytree_exists:
@ -253,14 +329,14 @@ def register_pytree_node(
cxx._private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
flatten_func,
unflatten_func,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
)
else:
args = (cls, flatten_fn, unflatten_fn)
args = (cls, flatten_func, unflatten_func)
kwargs = {
"serialized_type_name": serialized_type_name,
"to_dumpable_context": to_dumpable_context,
@ -380,7 +456,7 @@ def register_dataclass(
_flatten_fn,
_unflatten_fn,
serialized_type_name=serialized_type_name,
flatten_with_keys_fn=_flatten_fn_with_keys,
flatten_with_keys_func=_flatten_fn_with_keys,
)
@ -460,7 +536,7 @@ def register_constant(cls: type[Any]) -> None:
cls,
_flatten,
_unflatten,
flatten_with_keys_fn=_flatten_with_keys,
flatten_with_keys_func=_flatten_with_keys,
)
CONSTANT_NODES.add(cls)
@ -510,7 +586,7 @@ def _register_namedtuple(
serialized_type_name=serialized_type_name,
to_dumpable_context=_namedtuple_serialize,
from_dumpable_context=_namedtuple_deserialize,
flatten_with_keys_fn=_namedtuple_flatten_with_keys,
flatten_with_keys_func=_namedtuple_flatten_with_keys,
)
@ -570,7 +646,7 @@ def _register_pytree_node(
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
flatten_with_keys_fn=flatten_with_keys_fn,
flatten_with_keys_func=flatten_with_keys_fn,
)
@ -590,13 +666,13 @@ def _deregister_pytree_node(
def _private_register_pytree_node(
cls: type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
flatten_func: FlattenFunc,
unflatten_func: UnflattenFunc,
*,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
flatten_with_keys_func: FlattenWithKeysFunc | None = None,
) -> None:
"""This is an internal function that is used to register a pytree node type
for the Python pytree only. End-users should use :func:`register_pytree_node`
@ -611,7 +687,7 @@ def _private_register_pytree_node(
stacklevel=2,
)
node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn)
node_def = NodeDef(cls, flatten_func, unflatten_func, flatten_with_keys_func)
SUPPORTED_NODES[cls] = node_def
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
@ -958,21 +1034,21 @@ _private_register_pytree_node(
_tuple_flatten,
_tuple_unflatten,
serialized_type_name="builtins.tuple",
flatten_with_keys_fn=_tuple_flatten_with_keys,
flatten_with_keys_func=_tuple_flatten_with_keys,
)
_private_register_pytree_node(
list,
_list_flatten,
_list_unflatten,
serialized_type_name="builtins.list",
flatten_with_keys_fn=_list_flatten_with_keys,
flatten_with_keys_func=_list_flatten_with_keys,
)
_private_register_pytree_node(
dict,
_dict_flatten,
_dict_unflatten,
serialized_type_name="builtins.dict",
flatten_with_keys_fn=_dict_flatten_with_keys,
flatten_with_keys_func=_dict_flatten_with_keys,
)
_private_register_pytree_node(
namedtuple, # type: ignore[arg-type]
@ -981,14 +1057,14 @@ _private_register_pytree_node(
serialized_type_name="collections.namedtuple",
to_dumpable_context=_namedtuple_serialize,
from_dumpable_context=_namedtuple_deserialize,
flatten_with_keys_fn=_namedtuple_flatten_with_keys,
flatten_with_keys_func=_namedtuple_flatten_with_keys,
)
_private_register_pytree_node(
OrderedDict,
_ordereddict_flatten,
_ordereddict_unflatten,
serialized_type_name="collections.OrderedDict",
flatten_with_keys_fn=_ordereddict_flatten_with_keys,
flatten_with_keys_func=_ordereddict_flatten_with_keys,
)
_private_register_pytree_node(
defaultdict,
@ -997,14 +1073,14 @@ _private_register_pytree_node(
serialized_type_name="collections.defaultdict",
to_dumpable_context=_defaultdict_serialize,
from_dumpable_context=_defaultdict_deserialize,
flatten_with_keys_fn=_defaultdict_flatten_with_keys,
flatten_with_keys_func=_defaultdict_flatten_with_keys,
)
_private_register_pytree_node(
deque,
_deque_flatten,
_deque_unflatten,
serialized_type_name="collections.deque",
flatten_with_keys_fn=_deque_flatten_with_keys,
flatten_with_keys_func=_deque_flatten_with_keys,
)
@ -1186,8 +1262,8 @@ class TreeSpec:
f"Type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
children, context = flatten_fn(tree)
flatten_func = SUPPORTED_NODES[node_type].flatten_func
children, context = flatten_func(tree)
if len(children) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
@ -1238,8 +1314,8 @@ class TreeSpec:
children = [tree[key] for key in expected_keys]
else:
# node_type is treespec.type
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
children, context = flatten_fn(tree)
flatten_func = SUPPORTED_NODES[node_type].flatten_func
children, context = flatten_func(tree)
if (
node_type is not deque # ignore mismatch of `maxlen` for deque
) and context != treespec._context:
@ -1267,7 +1343,7 @@ class TreeSpec:
if self.is_leaf():
return leaves[0]
unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn
unflatten_func = SUPPORTED_NODES[self.type].unflatten_func
# Recursively unflatten the children
start = 0
@ -1278,7 +1354,7 @@ class TreeSpec:
child_pytrees.append(child_spec.unflatten(leaves[start:end]))
start = end
return unflatten_fn(child_pytrees, self._context)
return unflatten_func(child_pytrees, self._context)
def __hash__(self) -> int:
node_type = self.type
@ -1377,8 +1453,8 @@ def tree_flatten(
return _LEAF_SPEC
node_type = _get_node_type(node)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
children, context = flatten_fn(node)
flatten_func = SUPPORTED_NODES[node_type].flatten_func
children, context = flatten_func(node)
# Recursively flatten the children
subspecs = [helper(child, leaves) for child in children]
@ -1410,8 +1486,8 @@ def tree_iter(
yield tree
else:
node_type = _get_node_type(tree)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, _ = flatten_fn(tree)
flatten_func = SUPPORTED_NODES[node_type].flatten_func
child_pytrees, _ = flatten_func(tree)
# Recursively flatten the children
for child in child_pytrees:
@ -1838,8 +1914,8 @@ def _broadcast_to_and_flatten(
if node_type != treespec.type:
return None
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(tree)
flatten_func = SUPPORTED_NODES[node_type].flatten_func
child_pytrees, context = flatten_func(tree)
# Check if the Node is different from the spec
if len(child_pytrees) != treespec.num_children or context != treespec._context:
@ -2106,16 +2182,16 @@ def _generate_key_paths(
yield key_path, tree
return
flatten_with_keys = handler.flatten_with_keys_fn
flatten_with_keys = handler.flatten_with_keys_func
if flatten_with_keys:
key_children, _ = flatten_with_keys(tree)
for k, c in key_children:
yield from _generate_key_paths((*key_path, k), c, is_leaf)
else:
# We registered this pytree but didn't add a flatten_with_keys_fn, complain.
# We registered this pytree but didn't add a flatten_with_keys_func, complain.
raise ValueError(
f"Did not find a flatten_with_keys_fn for type: {node_type}. "
"Please pass a flatten_with_keys_fn argument to register_pytree_node."
f"Did not find a flatten_with_keys_func for type: {node_type}. "
"Please pass a flatten_with_keys_func argument to register_pytree_node."
)