mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[pytree] align function signature between C++ and Python pytree (#112482)
Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations. - #112485 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
7715b47f44
commit
4893a2814f
@ -58,6 +58,7 @@ __all__ = [
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
U = TypeVar("U")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
@ -79,11 +80,11 @@ def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
|
||||
|
||||
def register_pytree_node(
|
||||
cls: Type[Any],
|
||||
flatten_func: FlattenFunc,
|
||||
unflatten_func: UnflattenFunc,
|
||||
namespace: str = "torch",
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
namespace: str = "torch",
|
||||
) -> None:
|
||||
"""Extend the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
@ -99,20 +100,18 @@ def register_pytree_node(
|
||||
|
||||
Args:
|
||||
cls (type): A Python type to treat as an internal pytree node.
|
||||
flatten_fn (callable): A function to be used during flattening, taking an instance of ``cls``
|
||||
and returning a triple or optionally a pair, with (1) an iterable for the children to be
|
||||
flattened recursively, and (2) some hashable auxiliary data to be stored in the treespec
|
||||
and to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree
|
||||
path entries to the corresponding children. If the entries are not provided or given by
|
||||
:data:`None`, then `range(len(children))` will be used.
|
||||
unflatten_fn (callable): A function taking two arguments: the auxiliary data that was returned
|
||||
by ``flatten_func`` and stored in the treespec, and the unflattened children. The function
|
||||
should return an instance of ``cls``.
|
||||
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
||||
type registry. This is used to isolate the registry from other modules that might register
|
||||
a different custom behavior for the same type. (default: :const:`"torch"`)
|
||||
flatten_fn (callable): A function to be used during flattening, taking an instance of
|
||||
``cls`` and returning a pair, with (1) an iterable for the children to be flattened
|
||||
recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
|
||||
passed to the ``unflatten_fn``.
|
||||
unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
|
||||
returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
|
||||
The function should return an instance of ``cls``.
|
||||
serialized_type_name (str, optional): A keyword argument used to specify the fully
|
||||
qualified name used when serializing the tree spec.
|
||||
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
||||
type registry. This is used to isolate the registry from other modules that might
|
||||
register a different custom behavior for the same type. (default: :const:`"torch"`)
|
||||
|
||||
Example::
|
||||
|
||||
@ -198,15 +197,15 @@ def register_pytree_node(
|
||||
|
||||
_register_pytree_node(
|
||||
cls,
|
||||
flatten_func,
|
||||
unflatten_func,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
)
|
||||
|
||||
optree.register_pytree_node(
|
||||
cls,
|
||||
flatten_func,
|
||||
_reverse_args(unflatten_func),
|
||||
flatten_fn,
|
||||
_reverse_args(unflatten_fn),
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
@ -219,7 +218,7 @@ def tree_flatten(
|
||||
*,
|
||||
none_is_leaf: bool = True,
|
||||
namespace: str = "torch",
|
||||
) -> Tuple[List[Any], PyTreeSpec]:
|
||||
) -> Tuple[List[Any], TreeSpec]:
|
||||
"""Flatten a pytree.
|
||||
|
||||
See also :func:`tree_unflatten`.
|
||||
@ -269,7 +268,7 @@ def tree_flatten(
|
||||
)
|
||||
|
||||
|
||||
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
|
||||
def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
||||
"""Reconstruct a pytree from the treespec and the leaves.
|
||||
|
||||
The inverse of :func:`tree_flatten`.
|
||||
@ -282,16 +281,16 @@ def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
|
||||
Args:
|
||||
leaves (iterable): The list of leaves to use for reconstruction. The list must match the
|
||||
number of leaves of the treespec.
|
||||
treespec (PyTreeSpec): The treespec to reconstruct.
|
||||
treespec (TreeSpec): The treespec to reconstruct.
|
||||
|
||||
Returns:
|
||||
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
|
||||
``treespec``.
|
||||
"""
|
||||
if not isinstance(treespec, PyTreeSpec):
|
||||
if not isinstance(treespec, TreeSpec):
|
||||
raise TypeError(
|
||||
f"tree_unflatten(values, spec): Expected `spec` to be instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
f"TreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
|
||||
|
||||
@ -337,7 +336,7 @@ def tree_structure(
|
||||
*,
|
||||
none_is_leaf: bool = True,
|
||||
namespace: str = "torch",
|
||||
) -> PyTreeSpec:
|
||||
) -> TreeSpec:
|
||||
"""Get the treespec for a pytree.
|
||||
|
||||
See also :func:`tree_flatten`.
|
||||
@ -464,9 +463,11 @@ def tree_map_(
|
||||
|
||||
|
||||
Type2 = Tuple[Type[T], Type[S]]
|
||||
Type3 = Tuple[Type[T], Type[S], Type[U]]
|
||||
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]
|
||||
|
||||
Fn2 = Callable[[Union[T, S]], R]
|
||||
Fn3 = Callable[[Union[T, S, U]], R]
|
||||
Fn = Callable[[T], R]
|
||||
FnAny = Callable[[Any], R]
|
||||
|
||||
@ -480,6 +481,11 @@ def map_only(__type_or_types: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(__type_or_types: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(__type_or_types: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
|
||||
...
|
||||
@ -547,6 +553,18 @@ def tree_map_only(
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tree_map_only(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
*rests: PyTree,
|
||||
none_is_leaf: bool = True,
|
||||
namespace: str = "torch",
|
||||
) -> PyTree:
|
||||
...
|
||||
|
||||
|
||||
def tree_map_only(
|
||||
__type_or_types: TypeAny,
|
||||
func: FnAny[Any],
|
||||
@ -588,6 +606,18 @@ def tree_map_only_(
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tree_map_only_(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
*rests: PyTree,
|
||||
none_is_leaf: bool = True,
|
||||
namespace: str = "torch",
|
||||
) -> PyTree:
|
||||
...
|
||||
|
||||
|
||||
def tree_map_only_(
|
||||
__type_or_types: TypeAny,
|
||||
func: FnAny[Any],
|
||||
@ -651,6 +681,18 @@ def tree_all_only(
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tree_all_only(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
pred: Fn3[T, S, U, bool],
|
||||
tree: PyTree,
|
||||
*,
|
||||
none_is_leaf: bool = True,
|
||||
namespace: str = "torch",
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
|
||||
def tree_all_only(
|
||||
__type_or_types: TypeAny,
|
||||
pred: FnAny[bool],
|
||||
@ -687,6 +729,18 @@ def tree_any_only(
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tree_any_only(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
pred: Fn3[T, S, U, bool],
|
||||
tree: PyTree,
|
||||
*,
|
||||
none_is_leaf: bool = True,
|
||||
namespace: str = "torch",
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
|
||||
def tree_any_only(
|
||||
__type_or_types: TypeAny,
|
||||
pred: FnAny[bool],
|
||||
@ -764,12 +818,12 @@ def broadcast_prefix(
|
||||
# _broadcast_to_and_flatten to check this.
|
||||
def _broadcast_to_and_flatten(
|
||||
tree: PyTree,
|
||||
treespec: PyTreeSpec,
|
||||
treespec: TreeSpec,
|
||||
*,
|
||||
none_is_leaf: bool = True,
|
||||
namespace: str = "torch",
|
||||
) -> Optional[List[Any]]:
|
||||
assert isinstance(treespec, PyTreeSpec)
|
||||
assert isinstance(treespec, TreeSpec)
|
||||
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
|
||||
try:
|
||||
return broadcast_prefix(
|
||||
@ -782,12 +836,12 @@ def _broadcast_to_and_flatten(
|
||||
return None
|
||||
|
||||
|
||||
def treespec_dumps(treespec: PyTreeSpec) -> str:
|
||||
def treespec_dumps(treespec: TreeSpec) -> str:
|
||||
"""Serialize a treespec to a JSON string."""
|
||||
if not isinstance(treespec, PyTreeSpec):
|
||||
if not isinstance(treespec, TreeSpec):
|
||||
raise TypeError(
|
||||
f"treespec_dumps(spec): Expected `spec` to be instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
f"TreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
from ._pytree import (
|
||||
tree_structure as _tree_structure,
|
||||
@ -798,7 +852,7 @@ def treespec_dumps(treespec: PyTreeSpec) -> str:
|
||||
return _treespec_dumps(orig_treespec)
|
||||
|
||||
|
||||
def treespec_loads(serialized: str) -> PyTreeSpec:
|
||||
def treespec_loads(serialized: str) -> TreeSpec:
|
||||
"""Deserialize a treespec from a JSON string."""
|
||||
from ._pytree import (
|
||||
tree_unflatten as _tree_unflatten,
|
||||
@ -816,7 +870,7 @@ class _DummyLeaf:
|
||||
return "*"
|
||||
|
||||
|
||||
def treespec_pprint(treespec: PyTreeSpec) -> str:
|
||||
def treespec_pprint(treespec: TreeSpec) -> str:
|
||||
dummy_tree = tree_unflatten(
|
||||
[_DummyLeaf() for _ in range(treespec.num_leaves)],
|
||||
treespec,
|
||||
@ -824,14 +878,11 @@ def treespec_pprint(treespec: PyTreeSpec) -> str:
|
||||
return repr(dummy_tree)
|
||||
|
||||
|
||||
class PyTreeLeafSpecMeta(type(PyTreeSpec)): # type: ignore[misc]
|
||||
class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc]
|
||||
def __instancecheck__(self, instance: object) -> bool:
|
||||
return isinstance(instance, PyTreeSpec) and instance.is_leaf()
|
||||
return isinstance(instance, TreeSpec) and instance.is_leaf()
|
||||
|
||||
|
||||
class PyTreeLeafSpec(PyTreeSpec, metaclass=PyTreeLeafSpecMeta):
|
||||
def __new__(cls, none_is_leaf: bool = True) -> "PyTreeLeafSpec":
|
||||
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
|
||||
def __new__(cls, none_is_leaf: bool = True) -> "LeafSpec":
|
||||
return optree.treespec_leaf(none_is_leaf=none_is_leaf) # type: ignore[return-value]
|
||||
|
||||
|
||||
LeafSpec = PyTreeLeafSpec
|
||||
|
Reference in New Issue
Block a user