[pytree] add access api (#117771)

This PR introduces an API to use KeyPaths to actually access values on pytrees.

Differential Revision: [D52881260](https://our.internmc.facebook.com/intern/diff/D52881260/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117771
Approved by: https://github.com/zou3519, https://github.com/XuehaiPan
This commit is contained in:
suo
2024-01-19 12:32:48 -08:00
committed by PyTorch MergeBot
parent a1b3b5748f
commit e732adf0a7
3 changed files with 50 additions and 7 deletions

View File

@ -35,7 +35,7 @@ if torch._running_with_deploy():
import optree
from optree import PyTreeSpec # direct import for type annotations
from torch.utils._pytree import PHashable
from torch.utils._pytree import KeyEntry
__all__ = [
@ -49,6 +49,7 @@ __all__ = [
"TreeSpec",
"LeafSpec",
"keystr",
"key_get",
"register_pytree_node",
"tree_flatten",
"tree_flatten_with_path",
@ -86,7 +87,6 @@ OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
KeyEntry = PHashable
KeyPath = Tuple[KeyEntry, ...]
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]
@ -919,3 +919,8 @@ def tree_map_with_path(
def keystr(kp: KeyPath) -> str:
"""Given a key path, return a pretty-printed representation."""
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
def key_get(obj: Any, kp: KeyPath) -> Any:
"""Given an object and a key path, return the value at the key path."""
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")