mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Split from #92679. Use C++-based PyTree implementation.

## Highlights

1. High performance (20x speedup over the pure-Python implementation, 10%-20% overall speedup for `torch.fx`)
2. Multi-input tree-map support
3. Custom tree node registry with namespace isolation

Refs:

- #65761
- #91323
- #92679

From https://github.com/pytorch/pytorch/issues/65761#issuecomment-1334746366:

> ### 0. Out-of-box compatible with JAX's pytree, provides the same interfaces and functions (and more).
>
> ### 1. High-performance: `optree` has comparable fast tree operations (~0.9x for `dict`s and ~2.5x for `OrderedDict`s) than JAX's pytree and it is 20x faster than `torch.utils._pytree`.
>
> `optree` implements some common Python container types in C++ (e.g., `OrderedDict`) and achieves 2.5x performance than JAX's pytree. Check out section [Built-in PyTree Node Types](https://github.com/metaopt/optree#built-in-pytree-node-types) and [Benchmark](https://github.com/metaopt/optree#benchmark) for more details.
>
> | Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
> | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
> | TinyMLP | 53 | 26.40 | 68.19 | 586.87 | 34.14 | 2.58 | 22.23 | 1.29 |
> | AlexNet | 188 | 84.28 | 259.51 | 2182.07 | 125.12 | 3.08 | 25.89 | 1.48 |
> | ResNet18 | 698 | 288.57 | 807.27 | 7881.69 | 429.39 | 2.80 | 27.31 | 1.49 |
> | ResNet34 | 1242 | 580.75 | 1564.97 | 15082.84 | 819.02 | 2.69 | 25.97 | 1.41 |
> | ResNet50 | 1702 | 791.18 | 2081.17 | 20982.82 | 1104.62 | 2.63 | 26.52 | 1.40 |
> | ResNet101 | 3317 | 1603.93 | 3939.37 | 40382.14 | 2208.63 | 2.46 | 25.18 | 1.38 |
> | ResNet152 | 4932 | 2446.56 | 6267.98 | 56892.36 | 3139.17 | 2.56 | 23.25 | 1.28 |
> | ViT-H/14 | 3420 | 1681.48 | 4488.33 | 41703.16 | 2504.86 | 2.67 | 24.80 | 1.49 |
> | Swin-B | 2881 | 1565.41 | 4091.10 | 34241.99 | 1936.75 | 2.61 | 21.87 | 1.24 |
> | | | | | | **Average** | **2.68** | **24.78** | **1.38** |
>
> <div align="center">
>   <img src="https://user-images.githubusercontent.com/16078332/200494435-fd5bb385-59f7-4811-b520-98bf5763ccf3.png" width="90%" />
> </div>
>
> ### 2. Namespace Isolation for the PyTree Type Registry
>
> In addition to the JAX's pytree registry for custom node type registration, `optree` adds `namespace` isolation to the registry. Users can register the same type multiple times for different flatten/unflatten behavior. It also provides module-level isolation for safety reasons. For example, you can add a unique prefix to your namespace to isolate your registry with other modules (e.g., `torch.xxx`, `torch.functorch.xxx`):
>
> ```python
> # Register a Python type into a namespace
> import optree
> import torch
>
> optree.register_pytree_node(
>     torch.Tensor,
>     # (tensor) -> (children, metadata)
>     flatten_func=lambda tensor: (
>         (tensor.cpu().numpy(),),
>         dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
>     ),
>     # (metadata, children) -> tensor
>     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
>     namespace='torch.torch2numpy',
> )
> ```
>
> ```python
> >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
> >>> tree
> {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
>
> # Flatten without specifying the namespace
> >>> optree.tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
> ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
>
> # Flatten with the namespace
> >>> leaves, treespec = optree.tree_flatten(tree, namespace='torch.torch2numpy')
> >>> leaves, treespec
> (
>     [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
>     PyTreeSpec(
>         {
>             'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
>             'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
>         },
>         namespace='torch.torch2numpy'
>     )
> )
>
> # `entries` are not defined and use `range(len(children))`
> >>> optree.tree_paths(tree, namespace='torch.torch2numpy')
> [('bias', 0), ('weight', 0)]
>
> # Unflatten back to a copy of the original object
> >>> optree.tree_unflatten(treespec, leaves)
> {'bias': tensor([0., 0.]), 'weight': tensor([[1., 1.]], device='cuda:0')}
> ```
>
> Check out section [Registering a Container-like Custom Type as Non-leaf Nodes](https://github.com/metaopt/optree#notes-about-the-pytree-type-registry) for more details.
>
> ### 3. Support both `None` as Non-leaf Node and `None` as Leaf
>
> In JAX's implementation, `None` is always an internal non-leaf node with an arity 0, which is like an empty tuple. This limits the usage of the JAX's pytree utilities for PyTorch. For example, the `nn.Module` uses `_parameters` and `_buffers` (`OrderedDict[str, Optional[Tensor]]`) to hold the tensors, while the value can be a tensor or `None`.
>
> `optree` supports both `None` as Non-leaf Node (JAX's default) and `None` as Leaf (PyTorch's default). Check out section [None is Non-leaf Node vs. None is Leaf](https://github.com/metaopt/optree#none-is-non-leaf-node-vs-none-is-leaf) for more details.
>
> ### 4. Some other improvements and bug fixes
>
> 1. Adds in-place version of treemap (`tree_map_`), which reduces redundant unflatten operation for better performance.
> 2. Adds support for tree flatten and tree map with paths. (useful for `functorch` module extraction).
> 3. Improves the JAX's pytree sorting support for `dict`s.
> 4. Better string representation `repr(PyTreeSpec)`.
> 5. Fixes some bugs for JAX's pytree of hashing, pickle serialization, segmentation fault for infinite recursion, and tree-compose/tree-transpose.

From https://github.com/pytorch/pytorch/pull/92679#issuecomment-1398778481:

> ```python
> # pytree_make_fx_bench.py
> import torch
> from torch.fx.experimental.proxy_tensor import make_fx
> import time
>
> def f(x):
>     for _ in range(10000):
>         x = x+x
>     return x
>
> import time
> begin = time.time()
> out = make_fx(f, tracing_mode="real")(torch.randn(20))
> begin = time.time()
> print(f'tracing_mode="real" {time.time() - begin:.2f}')
> out = make_fx(f, tracing_mode="fake")(torch.randn(20))
> print(f'tracing_mode="fake" {time.time() - begin:.2f}')
>
> out = make_fx(f, tracing_mode="symbolic")(torch.randn(20))
> print(f'tracing_mode="symbolic" {time.time() - begin:.2f}')
> ```
>
> This seems to run around 10-20% faster with the optree implementation:
>
> ```
> # Optree
> python pytree_make_fx_bench.py
> tracing_mode="real" 0.00
> tracing_mode="fake" 6.32
> tracing_mode="symbolic" 27.13
> ```
>
> ```
> # torch.utils._pytree
> python pytree_make_fx_bench.py
> tracing_mode="real" 0.00
> tracing_mode="fake" 7.66
> tracing_mode="symbolic" 31.07
> ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93139
Approved by: https://github.com/malfet
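Section 3 of the quoted comment contrasts treating `None` as a childless internal node (JAX's convention) with treating it as a leaf (PyTorch's convention). The following is a minimal pure-Python sketch of what that choice changes. It only handles `dict`, `list`, and `tuple` containers, and the `flatten_leaves` helper and its `none_is_leaf` flag are hypothetical names for illustration, not optree's API. In the pure-Python module listed below, `None` is never registered in `SUPPORTED_NODES`, so it always behaves as a leaf.

```python
from typing import Any, List


def flatten_leaves(tree: Any, none_is_leaf: bool) -> List[Any]:
    """Depth-first leaves of a dict/list/tuple tree (hypothetical helper)."""
    if tree is None:
        # none_is_leaf=False: JAX-style, None is an internal node with zero children.
        # none_is_leaf=True: PyTorch-style, None is an ordinary leaf value.
        return [None] if none_is_leaf else []
    if isinstance(tree, dict):
        children = list(tree.values())
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    else:
        return [tree]
    leaves: List[Any] = []
    for child in children:
        leaves.extend(flatten_leaves(child, none_is_leaf))
    return leaves


# Shaped like an nn.Module `_parameters` dict, where the bias slot may hold None.
params = {"weight": 1.0, "bias": None}
print(flatten_leaves(params, none_is_leaf=False))  # [1.0]
print(flatten_leaves(params, none_is_leaf=True))   # [1.0, None]
```

With the JAX-style convention, a function mapped over the leaves never sees the empty `bias` slot. With the leaf convention, it sees the `None` and must handle it.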
549 lines
19 KiB
Python
from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional, TypeVar, overload, Union
import functools
from collections import namedtuple, OrderedDict
import dataclasses
import json
import warnings
|

T = TypeVar('T')
S = TypeVar('S')
U = TypeVar('U')
R = TypeVar('R')
|
"""
Contains utility functions for working with nested python data structures.

A *pytree* is a nested Python data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.

pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_flatten` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.

This pytree implementation is not very performant due to Python overhead.
To improve the performance we can move parts of the implementation to C++.
"""
|
DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1

Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
UnflattenFunc = Callable[[List, Context], PyTree]
DumpableContext = Any  # Any json dumpable value
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
ToStrFunc = Callable[["TreeSpec", List[str]], str]
MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]]
|
# A NodeDef holds two callables:
# - flatten_fn should take the collection and return a flat list of values.
#   It can also return some context that is used in reconstructing the
#   collection.
# - unflatten_fn should take a flat list of values and some context
#   (returned by flatten_fn). It returns the collection by reconstructing
#   it from the list and the context.
class NodeDef(NamedTuple):
    type: Type[Any]
    flatten_fn: FlattenFunc
    unflatten_fn: UnflattenFunc
|
SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}

# _SerializeNodeDef holds the following:
# - typ: the type of the node (e.g., dict, list, etc.)
# - type_fqn: the fully qualified name of the type, e.g. "collections.OrderedDict"
# - to_dumpable_context takes the context of a TreeSpec node and returns a
#   json dumpable representation of it
# - from_dumpable_context takes that json dumpable representation and returns
#   the original context
class _SerializeNodeDef(NamedTuple):
    typ: Type[Any]
    type_fqn: str
    to_dumpable_context: Optional[ToDumpableContextFn]
    from_dumpable_context: Optional[FromDumpableContextFn]
|
SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {}
SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {}
|
def _register_pytree_node(
    typ: Any,
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    to_str_fn: Optional[ToStrFunc] = None,
    maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,
    *,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
) -> None:
    """
    Args:
        typ: the type to register
        flatten_fn: A callable that takes a pytree and returns a flattened
            representation of the pytree and additional context to represent the
            flattened pytree.
        unflatten_fn: A callable that takes a flattened version of the pytree,
            additional context, and returns an unflattened pytree.
        to_dumpable_context: An optional keyword argument to specify how
            to convert the context of the pytree to a custom json dumpable
            representation. This is used for json serialization, which is being
            used in torch.export right now.
        from_dumpable_context: An optional keyword argument to specify how
            to convert the custom json dumpable representation of the context
            back to the original context. This is used for json deserialization,
            which is being used in torch.export right now.
    """
    if to_str_fn is not None or maybe_from_str_fn is not None:
        warnings.warn(
            "to_str_fn and maybe_from_str_fn are deprecated. "
            "Please use to_dumpable_context and from_dumpable_context instead."
        )

    # Validate before touching any registry so that a rejected call does not
    # leave `typ` registered for flattening but not for serialization.
    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
        raise ValueError(
            f"Both to_dumpable_context and from_dumpable_context for {typ} must "
            "be None or registered."
        )

    node_def = NodeDef(
        typ,
        flatten_fn,
        unflatten_fn,
    )
    SUPPORTED_NODES[typ] = node_def

    type_fqn = f"{typ.__module__}.{typ.__name__}"
    serialize_node_def = _SerializeNodeDef(
        typ, type_fqn, to_dumpable_context, from_dumpable_context
    )
    SUPPORTED_SERIALIZED_TYPES[typ] = serialize_node_def
    SERIALIZED_TYPE_TO_PYTHON_TYPE[type_fqn] = typ
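`_register_pytree_node` expects `flatten_fn` to return `(children, context)` and `unflatten_fn` to rebuild the node from `(children, context)`. The optional `to_dumpable_context`/`from_dumpable_context` hooks must be given together. The following self-contained sketch uses a hypothetical `Point` class (not part of PyTorch) to check those contracts without importing torch. The registration call itself appears only as a comment.

```python
import dataclasses
import json
from typing import Any, List, Tuple


@dataclasses.dataclass
class Point:  # hypothetical user type, for illustration only
    x: Any
    y: Any
    label: str  # static metadata: kept in the context, not flattened into leaves


def point_flatten(p: Point) -> Tuple[List[Any], Any]:
    # flatten_fn contract: return (children, context)
    return [p.x, p.y], p.label


def point_unflatten(values: List[Any], context: Any) -> Point:
    # unflatten_fn contract: rebuild the node from (children, context)
    x, y = values
    return Point(x, y, label=context)


def point_to_dumpable(context: Any) -> Any:
    # must return something json.dumps accepts
    return {"label": context}


def point_from_dumpable(dumpable: Any) -> Any:
    return dumpable["label"]


# With the module above, registration would look like this (pass both dumpable
# hooks or neither):
#   _register_pytree_node(Point, point_flatten, point_unflatten,
#                         to_dumpable_context=point_to_dumpable,
#                         from_dumpable_context=point_from_dumpable)

p = Point(1.0, 2.0, label="offset")
children, ctx = point_flatten(p)
assert point_unflatten(children, ctx) == p
# The dumpable form must survive a JSON round trip, as treespec_dumps/loads require.
assert point_from_dumpable(json.loads(json.dumps(point_to_dumpable(ctx)))) == ctx
print(children, ctx)  # [1.0, 2.0] offset
```

Keeping `label` in the context rather than in the children means `tree_map` and `tree_flatten` never see it as a leaf. It still round-trips through serialization.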
|

def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    return list(d.values()), list(d.keys())

def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
    return dict(zip(context, values))

def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
    return d, None

def _list_unflatten(values: List[Any], context: Context) -> List[Any]:
    return list(values)

def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
    return list(d), None

def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]:
    return tuple(values)

def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
    return list(d), type(d)

def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple:
    return cast(NamedTuple, context(*values))

def _namedtuple_serialize(context: Context) -> DumpableContext:
    json_namedtuple = {
        "class_name": context.__name__,
        "fields": context._fields,
    }
    return json_namedtuple

def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context:
    class_name = dumpable_context["class_name"]
    assert isinstance(class_name, str)
    context = namedtuple(class_name, dumpable_context["fields"])  # type: ignore[misc]
    return context

def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Context]:
    return list(d.values()), list(d.keys())

def _odict_unflatten(values: List[Any], context: Context) -> 'OrderedDict[Any, Any]':
    return OrderedDict((key, value) for key, value in zip(context, values))


_register_pytree_node(dict, _dict_flatten, _dict_unflatten)
_register_pytree_node(list, _list_flatten, _list_unflatten)
_register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten)
_register_pytree_node(
    namedtuple,
    _namedtuple_flatten,
    _namedtuple_unflatten,
    to_dumpable_context=_namedtuple_serialize,
    from_dumpable_context=_namedtuple_deserialize,
)
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
|

# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
def _is_namedtuple_instance(pytree: Any) -> bool:
    typ = type(pytree)
    bases = typ.__bases__
    if len(bases) != 1 or bases[0] != tuple:
        return False
    fields = getattr(typ, '_fields', None)
    if not isinstance(fields, tuple):
        return False
    return all(type(entry) == str for entry in fields)

def _get_node_type(pytree: Any) -> Any:
    if _is_namedtuple_instance(pytree):
        return namedtuple
    return type(pytree)

# A leaf is defined as anything that is not a Node.
def _is_leaf(pytree: PyTree) -> bool:
    return _get_node_type(pytree) not in SUPPORTED_NODES
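`_get_node_type` maps every namedtuple instance onto the single `namedtuple` registry key by duck-typing on `__bases__` and `_fields`. The following standalone copy of that heuristic is renamed `is_namedtuple_instance` so it runs without the module. It shows which classes pass. Both `collections.namedtuple` and `typing.NamedTuple` classes pass. A plain tuple does not, and neither does a subclass of a namedtuple class. For such a subclass, `_get_node_type` returns the subclass itself, which is not in `SUPPORTED_NODES`, so its instances are treated as leaves unless that subclass is registered separately.

```python
import collections
from typing import Any, NamedTuple


def is_namedtuple_instance(obj: Any) -> bool:
    # Same checks as _is_namedtuple_instance: a direct subclass of tuple
    # whose class-level _fields is a tuple of strings.
    typ = type(obj)
    bases = typ.__bases__
    if len(bases) != 1 or bases[0] != tuple:
        return False
    fields = getattr(typ, "_fields", None)
    if not isinstance(fields, tuple):
        return False
    return all(type(entry) == str for entry in fields)


Pair = collections.namedtuple("Pair", ["a", "b"])


class TypedPair(NamedTuple):
    a: int
    b: int


class LabeledPair(Pair):  # subclass of a namedtuple class
    pass


print(is_namedtuple_instance(Pair(1, 2)))         # True
print(is_namedtuple_instance(TypedPair(1, 2)))    # True
print(is_namedtuple_instance((1, 2)))             # False: plain tuple
print(is_namedtuple_instance(LabeledPair(1, 2)))  # False: its base is Pair, not tuple
```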
|

# A TreeSpec represents the structure of a pytree. It holds:
#   "type": the type of the root Node of the pytree
#   context: some context that is useful in unflattening the pytree
#   children_specs: specs for each child of the root Node
#   num_leaves: the number of leaves
@dataclasses.dataclass
class TreeSpec:
    type: Any
    context: Context
    children_specs: List['TreeSpec']

    def __post_init__(self) -> None:
        self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs])

    def __repr__(self, indent: int = 0) -> str:
        repr_prefix: str = f'TreeSpec({self.type.__name__}, {self.context}, ['
        children_specs_str: str = ''
        if len(self.children_specs):
            indent += 2
            children_specs_str += self.children_specs[0].__repr__(indent)
            children_specs_str += ',' if len(self.children_specs) > 1 else ''
            children_specs_str += ','.join(['\n' + ' ' * indent + child.__repr__(indent) for child in self.children_specs[1:]])
        repr_suffix: str = f'{children_specs_str}])'
        return repr_prefix + repr_suffix


class LeafSpec(TreeSpec):
    def __init__(self) -> None:
        super().__init__(None, None, [])
        self.num_leaves = 1

    def __repr__(self, indent: int = 0) -> str:
        return '*'
|
def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
    """Flattens a pytree into a list of values and a TreeSpec that can be used
    to reconstruct the pytree.
    """
    if _is_leaf(pytree):
        return [pytree], LeafSpec()

    node_type = _get_node_type(pytree)
    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
    child_pytrees, context = flatten_fn(pytree)

    # Recursively flatten the children
    result: List[Any] = []
    children_specs: List[TreeSpec] = []
    for child in child_pytrees:
        flat, child_spec = tree_flatten(child)
        result += flat
        children_specs.append(child_spec)

    return result, TreeSpec(node_type, context, children_specs)


def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
    """Given a list of values and a TreeSpec, builds a pytree.
    This is the inverse operation of `tree_flatten`.
    """
    if not isinstance(spec, TreeSpec):
        raise TypeError(
            f'tree_unflatten(values, spec): Expected `spec` to be instance of '
            f'TreeSpec but got item of type {type(spec)}.')
    if len(values) != spec.num_leaves:
        raise ValueError(
            f'tree_unflatten(values, spec): `values` has length {len(values)} '
            f'but the spec refers to a pytree that holds {spec.num_leaves} '
            f'items ({spec}).')
    if isinstance(spec, LeafSpec):
        return values[0]

    unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn

    # Recursively unflatten the children
    start = 0
    end = 0
    child_pytrees = []
    for child_spec in spec.children_specs:
        end += child_spec.num_leaves
        child_pytrees.append(tree_unflatten(values[start:end], child_spec))
        start = end

    return unflatten_fn(child_pytrees, spec.context)

def tree_map(fn: Any, pytree: PyTree) -> PyTree:
    flat_args, spec = tree_flatten(pytree)
    return tree_unflatten([fn(i) for i in flat_args], spec)
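`tree_flatten` records one spec node per container. `tree_unflatten` uses each child's `num_leaves` to slice the flat list back apart, and `tree_map` composes the two. The following is a miniature, self-contained re-sketch of that round trip for exact `dict`, `list`, and `tuple` instances only. It makes two simplifications of its own, not the module's: a plain `(node_type, context, child_specs)` tuple stands in for `TreeSpec`, and `None` stands in for `LeafSpec`. The `mini_*` names are hypothetical.

```python
from typing import Any, List, Optional, Tuple

# A node spec is (node_type, context, child_specs); None marks a leaf.
Spec = Optional[Tuple[type, Any, List[Any]]]


def count_leaves(spec: Spec) -> int:
    if spec is None:
        return 1
    return sum(count_leaves(child) for child in spec[2])


def mini_flatten(tree: Any) -> Tuple[List[Any], Spec]:
    if type(tree) is dict:
        node_type, context, children = dict, list(tree.keys()), list(tree.values())
    elif type(tree) in (list, tuple):
        node_type, context, children = type(tree), None, list(tree)
    else:
        return [tree], None  # anything unregistered is a leaf
    leaves: List[Any] = []
    child_specs: List[Spec] = []
    for child in children:
        child_leaves, child_spec = mini_flatten(child)
        leaves += child_leaves
        child_specs.append(child_spec)
    return leaves, (node_type, context, child_specs)


def mini_unflatten(values: List[Any], spec: Spec) -> Any:
    if spec is None:
        return values[0]
    node_type, context, child_specs = spec
    children: List[Any] = []
    start = 0
    for child_spec in child_specs:
        end = start + count_leaves(child_spec)  # this child's slice of the flat list
        children.append(mini_unflatten(values[start:end], child_spec))
        start = end
    if node_type is dict:
        return dict(zip(context, children))
    return node_type(children)


def mini_tree_map(fn: Any, tree: Any) -> Any:
    leaves, spec = mini_flatten(tree)
    return mini_unflatten([fn(leaf) for leaf in leaves], spec)


tree = {"a": [1, 2], "b": (3, {"c": 4})}
leaves, spec = mini_flatten(tree)
print(leaves)                                 # [1, 2, 3, 4]
print(mini_tree_map(lambda x: x * 10, tree))  # {'a': [10, 20], 'b': (30, {'c': 40})}
```

`TreeSpec.__post_init__` computes `num_leaves` once per node. This sketch recounts on every call, which repeats work on deep trees but keeps the example short.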
|
Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

Fn3 = Callable[[Union[T, S, U]], R]
Fn2 = Callable[[Union[T, S]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

MapOnlyFn = Callable[[T], Callable[[Any], Any]]

# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(ty: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...

@overload
def map_only(ty: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
    ...

# This specialization is needed for the implementations below that call
@overload
def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    ...

def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    """
    Suppose you are writing a tree_map over tensors, leaving everything
    else unchanged. Ordinarily you would have to write:

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    With this function, you only need to write:

        @map_only(Tensor)
        def go(t):
            return ...

    You can also directly use 'tree_map_only'
    """
    def deco(f: Callable[[T], Any]) -> Callable[[Any], Any]:
        @functools.wraps(f)
        def inner(x: T) -> Any:
            if isinstance(x, ty):
                return f(x)
            else:
                return x
        return inner
    return deco
|
@overload
def tree_map_only(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree:
    ...

@overload
def tree_map_only(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree:
    ...

@overload
def tree_map_only(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) -> PyTree:
    ...

def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:
    return tree_map(map_only(ty)(fn), pytree)

def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return all(map(pred, flat_args))

def tree_any(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return any(map(pred, flat_args))

@overload
def tree_all_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
    ...

@overload
def tree_all_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
    ...

@overload
def tree_all_only(ty: Type3[T, S, U], pred: Fn3[T, S, U, bool], pytree: PyTree) -> bool:
    ...

def tree_all_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return all(pred(x) for x in flat_args if isinstance(x, ty))

@overload
def tree_any_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
    ...

@overload
def tree_any_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
    ...

def tree_any_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return any(pred(x) for x in flat_args if isinstance(x, ty))
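`tree_all_only` and `tree_any_only` apply `pred` only to leaves of the requested type and ignore every other leaf. If a tree has no leaves of that type, `tree_all_only` is therefore vacuously `True` and `tree_any_only` is `False`. The following sketch shows that edge case. It uses hypothetical helpers `all_only`/`any_only` that apply the same filter-then-test expressions to an already-flattened leaf list, so it runs without the module.

```python
from typing import Any, Callable, List, Tuple, Type, Union

TypeSpec = Union[Type[Any], Tuple[Type[Any], ...]]


def all_only(ty: TypeSpec, pred: Callable[[Any], bool], leaves: List[Any]) -> bool:
    # Same expression as tree_all_only, minus the tree_flatten call.
    return all(pred(x) for x in leaves if isinstance(x, ty))


def any_only(ty: TypeSpec, pred: Callable[[Any], bool], leaves: List[Any]) -> bool:
    return any(pred(x) for x in leaves if isinstance(x, ty))


# e.g. the leaves of {"name": "a", "lr": 2.5, "mask": None}; None is a leaf here.
leaves = ["a", 2.5, None]
print(all_only(int, lambda x: x > 0, leaves))  # True: no int leaves, so all() is vacuous
print(any_only(int, lambda x: x > 0, leaves))  # False

mixed = [1, -2, "b"]
print(all_only(int, lambda x: x > 0, mixed))   # False
print(any_only(int, lambda x: x > 0, mixed))   # True
```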
|
# Broadcasts a pytree to the provided TreeSpec and returns the flattened
# values. If this is not possible, then this function returns None.
#
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
# this function returns [0, 0]. This is useful for part of the vmap implementation:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and we use
# _broadcast_to_and_flatten to check this.
def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]:
    assert isinstance(spec, TreeSpec)

    if _is_leaf(pytree):
        return [pytree] * spec.num_leaves
    if isinstance(spec, LeafSpec):
        return None
    node_type = _get_node_type(pytree)
    if node_type != spec.type:
        return None

    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
    child_pytrees, ctx = flatten_fn(pytree)

    # Check if the Node is different from the spec
    if len(child_pytrees) != len(spec.children_specs) or ctx != spec.context:
        return None

    # Recursively flatten the children
    result: List[Any] = []
    for child, child_spec in zip(child_pytrees, spec.children_specs):
        flat = _broadcast_to_and_flatten(child, child_spec)
        if flat is not None:
            result += flat
        else:
            return None

    return result
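The comment above gives the canonical case: a scalar broadcast against a two-leaf list spec yields `[0, 0]`. The following is a stripped-down, self-contained sketch of the same recursion for nested lists only. It uses a hypothetical `LEAF` marker and plain lists in place of `TreeSpec`. It mirrors the three outcomes of `_broadcast_to_and_flatten`:

- a leaf is repeated `num_leaves` times;
- a structural mismatch yields `None`;
- otherwise the children are flattened and concatenated.

For brevity it skips the node-type and context comparison.

```python
from typing import Any, List, Optional

LEAF = "*"  # hypothetical leaf marker, echoing LeafSpec.__repr__


def num_leaves(spec: Any) -> int:
    # A spec is either LEAF or a list of child specs (one per list element).
    if isinstance(spec, list):
        return sum(num_leaves(child) for child in spec)
    return 1


def broadcast_to_and_flatten(tree: Any, spec: Any) -> Optional[List[Any]]:
    if not isinstance(tree, list):
        # A leaf (including None) is repeated once per leaf position under `spec`.
        return [tree] * num_leaves(spec)
    if not isinstance(spec, list) or len(tree) != len(spec):
        # A list cannot be broadcast onto a leaf or onto a list of another length.
        return None
    result: List[Any] = []
    for child, child_spec in zip(tree, spec):
        flat = broadcast_to_and_flatten(child, child_spec)
        if flat is None:
            return None
        result += flat
    return result


# Structure of vmap-style inputs shaped like [[x, y], z].
inputs_spec = [[LEAF, LEAF], LEAF]
print(broadcast_to_and_flatten(0, inputs_spec))          # [0, 0, 0]
print(broadcast_to_and_flatten([1, None], inputs_spec))  # [1, 1, None]
print(broadcast_to_and_flatten([0, 1, 2], inputs_spec))  # None
```

This is why an `in_dims` such as `[1, None]` can be shorter than the full input structure. Each entry only has to match its subtree's top level, and leaves fan out from there.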
|

"""
_TreeSpecSchema is the schema used to serialize the TreeSpec.
It contains the following fields:
- type: A string name of the type. null for the case of a LeafSpec.
- context: Any format which is json dumpable
- children_spec: A list of children serialized specs.
"""
@dataclasses.dataclass
class _TreeSpecSchema:
    type: Optional[str]
    context: DumpableContext
    children_spec: List['_TreeSpecSchema']

class _ProtocolFn(NamedTuple):
    treespec_to_json: Callable[[TreeSpec], DumpableContext]
    json_to_treespec: Callable[[DumpableContext], TreeSpec]

_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {}
|

def _treespec_to_json(spec: TreeSpec) -> _TreeSpecSchema:
    if isinstance(spec, LeafSpec):
        return _TreeSpecSchema(None, None, [])

    if spec.type not in SUPPORTED_SERIALIZED_TYPES:
        raise NotImplementedError(f"Serializing {spec.type} in pytree is not registered.")

    serialize_node_def = SUPPORTED_SERIALIZED_TYPES[spec.type]

    type_fqn = serialize_node_def.type_fqn

    if serialize_node_def.to_dumpable_context is None:
        try:
            serialized_context = json.dumps(spec.context)
        except TypeError as e:
            raise TypeError(
                "Unable to serialize context. "
                "Please make the context json dump-able, or register a "
                "custom serializer using _register_pytree_node."
            ) from e
    else:
        serialized_context = serialize_node_def.to_dumpable_context(spec.context)

    child_schemas = [_treespec_to_json(child) for child in spec.children_specs]

    return _TreeSpecSchema(type_fqn, serialized_context, child_schemas)
|
def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
    if (
        json_schema["type"] is None and
        json_schema["context"] is None and
        len(json_schema["children_spec"]) == 0
    ):
        return LeafSpec()

    if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
        raise NotImplementedError(f'Deserializing {json_schema["type"]} in pytree is not registered.')

    typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]]
    serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ]

    if serialize_node_def.from_dumpable_context is None:
        try:
            context = json.loads(json_schema["context"])
        except TypeError as e:
            raise TypeError(
                "Unable to deserialize context. "
                "Please make the context json load-able, or register a "
                "custom serializer using _register_pytree_node."
            ) from e
    else:
        context = serialize_node_def.from_dumpable_context(json_schema["context"])

    children_spec = []
    for child_string in json_schema["children_spec"]:
        children_spec.append(_json_to_treespec(child_string))

    return TreeSpec(typ, context, children_spec)
|

_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)


def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
    if protocol is None:
        protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL

    if protocol in _SUPPORTED_PROTOCOLS:
        json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec)
    else:
        raise ValueError(f"Unknown protocol {protocol}. Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}")

    str_spec = json.dumps((protocol, dataclasses.asdict(json_spec)))
    return str_spec
|
def treespec_loads(data: str) -> TreeSpec:
    protocol, json_schema = json.loads(data)

    if protocol in _SUPPORTED_PROTOCOLS:
        return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema)
    raise ValueError(f"Unknown protocol {protocol}. Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}")
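`treespec_dumps` writes a JSON array `[protocol, schema]`, where each schema node has `type`, `context`, and `children_spec` keys, in the field order of `_TreeSpecSchema`. When no `to_dumpable_context` is registered, `_treespec_to_json` passes the context through `json.dumps` first, so it shows up as a JSON string nested inside the document. The following sketch builds that protocol-1 layout by hand with the standard library for a tree of `dict`s, `list`s, and leaves. It reflects a reading of the code above; it does not call into PyTorch. `schema_for` is a hypothetical helper.

```python
import json
from typing import Any, Dict


def schema_for(tree: Any) -> Dict[str, Any]:
    """Hand-built protocol-1 schema for trees of dicts, lists and leaves (sketch)."""
    if isinstance(tree, dict):
        # type_fqn is f"{typ.__module__}.{typ.__name__}", i.e. "builtins.dict"
        type_fqn, context, children = "builtins.dict", list(tree.keys()), list(tree.values())
    elif isinstance(tree, list):
        type_fqn, context, children = "builtins.list", None, tree
    else:
        # LeafSpec: null type, null context, no children
        return {"type": None, "context": None, "children_spec": []}
    return {
        "type": type_fqn,
        # default path (no to_dumpable_context): the context is json.dumps'd to a string
        "context": json.dumps(context),
        "children_spec": [schema_for(child) for child in children],
    }


data = json.dumps((1, schema_for({"a": 1, "b": [2, 3]})))
protocol, schema = json.loads(data)
print(protocol)                            # 1
print(schema["context"])                   # ["a", "b"]
print(json.loads(schema["context"]))       # ['a', 'b']
print(schema["children_spec"][1]["type"])  # builtins.list
```

The double encoding explains why `_json_to_treespec` calls `json.loads` on each context again. Contexts produced by a registered `to_dumpable_context`, such as the namedtuple dict, are embedded directly instead.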
|
# TODO(angelayi): remove this function after OSS/internal stabilize
def pytree_to_str(spec: TreeSpec) -> str:
    warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps")
    return treespec_dumps(spec)

# TODO(angelayi): remove this function after OSS/internal stabilize
def str_to_pytree(json: str) -> TreeSpec:
    warnings.warn("str_to_pytree is deprecated. Please use treespec_loads")
    return treespec_loads(json)