[pytree] test aligned API signature for C++ and Python pytree (#112485)

Add tests to ensure the C++ and Python pytree provide the same APIs with identical signatures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112485
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2023-11-30 22:55:35 +08:00
committed by PyTorch MergeBot
parent e6b3a8ce5f
commit 2a3d8e50fb
6 changed files with 110 additions and 15 deletions

View File

@ -71,7 +71,7 @@ U = TypeVar("U")
R = TypeVar("R")
Context = Optional[Any]
Context = Any
PyTree = Any
TreeSpec = PyTreeSpec
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
@ -1025,7 +1025,7 @@ def _broadcast_to_and_flatten(
return None
def treespec_dumps(treespec: TreeSpec) -> str:
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
"""Serialize a treespec to a JSON string."""
if not isinstance(treespec, TreeSpec):
raise TypeError(
@ -1038,7 +1038,7 @@ def treespec_dumps(treespec: TreeSpec) -> str:
)
orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
return _treespec_dumps(orig_treespec)
return _treespec_dumps(orig_treespec, protocol=protocol)
def treespec_loads(serialized: str) -> TreeSpec: