[ghstack-poisoned]
This commit is contained in:
Xuehai Pan
2024-10-16 16:34:57 +08:00

View File

@ -866,6 +866,8 @@ def _broadcast_to_and_flatten(
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
"""Serialize a treespec to a JSON string."""
if 'treespec' not in type(treespec).__name__:
raise TypeError("treespec_dumps expects a treespec object.")
dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
orig_treespec = python_pytree.tree_structure(dummy_tree)
return python_pytree.treespec_dumps(orig_treespec, protocol=protocol)