mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytree] add access api (#117771)
This PR introduces an API to use KeyPaths to actually access values on pytrees. Differential Revision: [D52881260](https://our.internmc.facebook.com/intern/diff/D52881260/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/117771 Approved by: https://github.com/zou3519, https://github.com/XuehaiPan
This commit is contained in:
@ -35,7 +35,7 @@ if torch._running_with_deploy():
|
||||
import optree
|
||||
from optree import PyTreeSpec # direct import for type annotations
|
||||
|
||||
from torch.utils._pytree import PHashable
|
||||
from torch.utils._pytree import KeyEntry
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -49,6 +49,7 @@ __all__ = [
|
||||
"TreeSpec",
|
||||
"LeafSpec",
|
||||
"keystr",
|
||||
"key_get",
|
||||
"register_pytree_node",
|
||||
"tree_flatten",
|
||||
"tree_flatten_with_path",
|
||||
@ -86,7 +87,6 @@ OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
|
||||
DumpableContext = Any # Any json dumpable text
|
||||
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
||||
KeyEntry = PHashable
|
||||
KeyPath = Tuple[KeyEntry, ...]
|
||||
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]
|
||||
|
||||
@ -919,3 +919,8 @@ def tree_map_with_path(
|
||||
def keystr(kp: KeyPath) -> str:
|
||||
"""Given a key path, return a pretty-printed representation."""
|
||||
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
|
||||
|
||||
|
||||
def key_get(obj: Any, kp: KeyPath) -> Any:
|
||||
"""Given an object and a key path, return the value at the key path."""
|
||||
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
|
||||
|
Reference in New Issue
Block a user