mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-04 16:04:58 +08:00
[pytree] test aligned API signature for C++ and Python pytree (#112485)
Add tests to ensure the C++ and Python pytree provide the same APIs with identical signatures. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112485 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
e6b3a8ce5f
commit
2a3d8e50fb
@ -71,7 +71,7 @@ U = TypeVar("U")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
Context = Optional[Any]
|
||||
Context = Any
|
||||
PyTree = Any
|
||||
TreeSpec = PyTreeSpec
|
||||
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
|
||||
@ -1025,7 +1025,7 @@ def _broadcast_to_and_flatten(
|
||||
return None
|
||||
|
||||
|
||||
def treespec_dumps(treespec: TreeSpec) -> str:
|
||||
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
|
||||
"""Serialize a treespec to a JSON string."""
|
||||
if not isinstance(treespec, TreeSpec):
|
||||
raise TypeError(
|
||||
@ -1038,7 +1038,7 @@ def treespec_dumps(treespec: TreeSpec) -> str:
|
||||
)
|
||||
|
||||
orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
|
||||
return _treespec_dumps(orig_treespec)
|
||||
return _treespec_dumps(orig_treespec, protocol=protocol)
|
||||
|
||||
|
||||
def treespec_loads(serialized: str) -> TreeSpec:
|
||||
|
||||
Reference in New Issue
Block a user