[pytree] Use OpTree for PyTree manipulation (#93139)

Split from #92679. Use C++-based PyTree implementation.

## Highlights

1. High performance (20x speedup than the pure-Python implementation, 10%-20% overall speedup for `torch.fx`)
2. Multi-input tree-map support
3. Custom tree node registry with namespace isolation

Refs:

- #65761
- #91323
- #92679

From https://github.com/pytorch/pytorch/issues/65761#issuecomment-1334746366:

> ### 0. Out-of-box compatible with JAX's pytree, provides the same interfaces and functions (and more).
>
> ### 1. High-performance: `optree` has comparable fast tree operations (~0.9x for `dict`s and ~2.5x for `OrderedDict`s) than JAX's pytree and it is 20x faster than `torch.utils._pytree`.
>
> `optree` implements some common Python container types in C++ (e.g., `OrderedDict`) and achieves 2.5x performance than JAX's pytree. Check out section [Built-in PyTree Node Types](https://github.com/metaopt/optree#built-in-pytree-node-types) and [Benchmark](https://github.com/metaopt/optree#benchmark) for more details.
>
> | Module    | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
> | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
> | TinyMLP   |    53 |       26.40 |        68.19 |       586.87 |        34.14 |            2.58 |           22.23 |            1.29 |
> | AlexNet   |   188 |       84.28 |       259.51 |      2182.07 |       125.12 |            3.08 |           25.89 |            1.48 |
> | ResNet18  |   698 |      288.57 |       807.27 |      7881.69 |       429.39 |            2.80 |           27.31 |            1.49 |
> | ResNet34  |  1242 |      580.75 |      1564.97 |     15082.84 |       819.02 |            2.69 |           25.97 |            1.41 |
> | ResNet50  |  1702 |      791.18 |      2081.17 |     20982.82 |      1104.62 |            2.63 |           26.52 |            1.40 |
> | ResNet101 |  3317 |     1603.93 |      3939.37 |     40382.14 |      2208.63 |            2.46 |           25.18 |            1.38 |
> | ResNet152 |  4932 |     2446.56 |      6267.98 |     56892.36 |      3139.17 |            2.56 |           23.25 |            1.28 |
> | ViT-H/14  |  3420 |     1681.48 |      4488.33 |     41703.16 |      2504.86 |            2.67 |           24.80 |            1.49 |
> | Swin-B    |  2881 |     1565.41 |      4091.10 |     34241.99 |      1936.75 |            2.61 |           21.87 |            1.24 |
> |           |       |             |              |              |  **Average** |        **2.68** |       **24.78** |        **1.38** |
>
> <div align="center">
>   <img src="https://user-images.githubusercontent.com/16078332/200494435-fd5bb385-59f7-4811-b520-98bf5763ccf3.png" width="90%" />
> </div>
>
> ### 2. Namespace Isolation for the PyTree Type Registry
>
> In addition to the JAX's pytree registry for custom node type registration, `optree` adds `namespace` isolation to the registry. Users can register the same type multiple times for different flatten/unflatten behavior. It also provides module-level isolation for safety reasons. For example, you can add a unique prefix to your namespace to isolate your registry with other modules (e.g., `torch.xxx`, `torch.functorch.xxx`):
>
> ```python
> # Register a Python type into a namespace
> import torch
>
> optree.register_pytree_node(
>     torch.Tensor,
>     # (tensor) -> (children, metadata)
>     flatten_func=lambda tensor: (
>         (tensor.cpu().numpy(),),
>         dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
>     ),
>     # (metadata, children) -> tensor
>     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
>     namespace='torch.torch2numpy',
> )
> ```
>
> ```python
> >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
> >>> tree
> {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
>
> # Flatten without specifying the namespace
> >>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
> ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
>
> # Flatten with the namespace
> >>> leaves, treespec = optree.tree_flatten(tree, namespace='torch.torch2numpy')
> >>> leaves, treespec
> (
>     [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
>     PyTreeSpec(
>         {
>             'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
>             'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
>         },
>         namespace='torch.torch2numpy'
>     )
> )
>
> # `entries` are not defined and use `range(len(children))`
> >>> optree.tree_paths(tree, namespace='torch.torch2numpy')
> [('bias', 0), ('weight', 0)]
>
> # Unflatten back to a copy of the original object
> >>> optree.tree_unflatten(treespec, leaves)
> {'bias': tensor([0., 0.]), 'weight': tensor([[1., 1.]], device='cuda:0')}
> ```
>
> Check out section [Registering a Container-like Custom Type as Non-leaf Nodes](https://github.com/metaopt/optree#notes-about-the-pytree-type-registry) for more details.
>
> ### 3. Support both `None` as Non-leaf Node and `None` as Leaf
>
> In JAX's implementation, `None` is always an internal non-leaf node with an arity 0, which is like an empty tuple. This limits the usage of the JAX's pytree utilities for PyTorch. For example, the `nn.Module` uses `_parameters` and `_buffers` (`OrderedDict[str, Optional[Tensor]]`) to hold the tensors, while the value can be a tensor or `None`.
>
> `optree` supports both `None` as Non-leaf Node (JAX's default) and `None` as Leaf (PyTorch's default). Check out section [None is Non-leaf Node vs. None is Leaf](https://github.com/metaopt/optree#none-is-non-leaf-node-vs-none-is-leaf) for more details.
>
> ### 4. Some other improvements and bug fixes
>
> 1. Adds in-place version of treemap (`tree_map_`), which reduces redundant unflatten operation for better performance.
> 2. Adds support for tree flatten and tree map with paths. (useful for `functorch` module extraction).
> 3. Improves the JAX's pytree sorting support for `dict`s.
> 4. Better string representation `repr(PyTreeSpec)`.
> 5. Fixes some bugs for JAX's pytree of hashing, pickle serialization, segmentation fault for infinite recursion, and tree-compose/tree-transpose.

From https://github.com/pytorch/pytorch/pull/92679#issuecomment-1398778481:

> ```python
> # pytree_make_fx_bench.py
> import torch
> from torch.fx.experimental.proxy_tensor import make_fx
> import time
>
> def f(x):
>     for _ in range(10000):
>         x = x+x
>     return x
>
> import time
> begin = time.time()
> out = make_fx(f, tracing_mode="real")(torch.randn(20))
> begin = time.time()
> print(f'tracing_mode="real" {time.time() - begin:.2f}')
> out = make_fx(f, tracing_mode="fake")(torch.randn(20))
> print(f'tracing_mode="fake" {time.time() - begin:.2f}')
>
> out = make_fx(f, tracing_mode="symbolic")(torch.randn(20))
> print(f'tracing_mode="symbolic" {time.time() - begin:.2f}')
> ```
>
> This seems to run around 10-20% faster with the optree implementation:
>
> ```
> # Optree
> python pytree_make_fx_bench.py
> tracing_mode="real" 0.00
> tracing_mode="fake" 6.32
> tracing_mode="symbolic" 27.13
> ```
>
> ```
> # torch.utils._pytree
> python pytree_make_fx_bench.py
> tracing_mode="real" 0.00
> tracing_mode="fake" 7.66
> tracing_mode="symbolic" 31.07
> ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93139
Approved by: https://github.com/malfet
This commit is contained in:
Xuehai Pan
2023-09-18 21:24:56 +00:00
committed by PyTorch MergeBot
parent 8a567bb59d
commit 0bf30c140a
12 changed files with 1367 additions and 184 deletions

View File

@ -124,6 +124,19 @@ opt-einsum==3.3
#Pinned versions: 3.3
#test that import: test_linalg.py
optree==0.9.1
#Description: A library for tree manipulation
#Pinned versions: 0.9.1
#test that import: test_vmap.py, test_aotdispatch.py, test_dynamic_shapes.py,
#test_pytree.py, test_ops.py, test_control_flow.py, test_modules.py,
#common_utils.py, test_eager_transforms.py, test_python_dispatch.py,
#test_expanded_weights.py, test_decomp.py, test_overrides.py, test_masked.py,
#test_ops.py, test_prims.py, test_subclass.py, test_functionalization.py,
#test_schema_check.py, test_profiler_tree.py, test_meta.py, test_torchxla_num_output.py,
#test_utils.py, test_proxy_tensor.py, test_memory_profiler.py, test_view_ops.py,
#test_pointwise_ops.py, test_dtensor_ops.py, test_torchinductor.py, test_fx.py,
#test_fake_tensor.py, test_mps.py
pillow==9.3.0 ; python_version <= "3.8"
pillow==9.5.0 ; python_version > "3.8"
#Description: Python Imaging Library fork

View File

@ -1,3 +1,4 @@
# iOS simulator requirements
coremltools==5.0b5
protobuf==3.20.2
optree==0.9.1

View File

@ -26,3 +26,4 @@ pytest-cpp==2.3.0
rockset==1.0.3
z3-solver==4.12.2.0
tensorboard==2.13.0
optree==0.9.1

View File

@ -161,8 +161,8 @@ jobs:
PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }}
run: |
pushd "${PYTORCH_FINAL_PACKAGE_DIR}"
# shellcheck disable=SC2046
python3 -mpip install $(echo *.whl)[opt-einsum]
# shellcheck disable=SC2046,SC2102
python3 -mpip install $(echo *.whl)[opt-einsum,optree]
popd
.ci/pytorch/win-test.sh

View File

@ -171,6 +171,7 @@ init_command = [
'junitparser==2.1.1',
'rich==10.9.0',
'pyyaml==6.0',
'optree==0.9.1',
]
[[linter]]
@ -231,6 +232,7 @@ include_patterns = [
'tools/**/*.py',
'torchgen/**/*.py',
'torch/utils/_pytree.py',
'torch/utils/pytree.py',
'torch/utils/benchmark/utils/common.py',
'torch/utils/benchmark/utils/timer.py',
'torch/utils/benchmark/utils/valgrind_wrapper/**/*.py',

View File

@ -43,6 +43,7 @@ files =
tools,
torch/profiler/_memory_profiler.py,
torch/utils/_pytree.py,
torch/utils/pytree.py,
torch/utils/benchmark/utils/common.py,
torch/utils/benchmark/utils/timer.py,
torch/utils/benchmark/utils/valgrind_wrapper

View File

@ -17,3 +17,4 @@ fsspec
# setuptools was removed from default python install
setuptools ; python_version >= "3.12"
packaging
optree>=0.9.1

View File

@ -1175,6 +1175,7 @@ def main():
install_requires += extra_install_requires
extras_require = {
"optree": ["optree>=0.9.1"],
"opt-einsum": ["opt-einsum>=3.3"],
}
# Triton is only available on Linux atm

View File

@ -1,322 +1,389 @@
# Owner(s): ["module: pytree"]
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.utils._pytree import (
tree_flatten,
tree_map,
tree_unflatten,
TreeSpec,
LeafSpec,
treespec_dumps,
treespec_loads,
_register_pytree_node,
)
import pickle
import unittest
from torch.utils._pytree import _broadcast_to_and_flatten, tree_map_only, tree_all
from torch.utils._pytree import tree_any, tree_all_only, tree_any_only
from collections import namedtuple, OrderedDict
from torch.testing._internal.common_utils import parametrize, subtest, instantiate_parametrized_tests, TEST_WITH_TORCHDYNAMO
import torch
import torch.utils._pytree as py_pytree
import torch.utils.pytree as cxx_pytree
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
)
GlobalPoint = namedtuple("GlobalPoint", ["x", "y"])
class GlobalDummyType:
def __init__(self, x, y):
self.x = x
self.y = y
class TestPytree(TestCase):
def test_treespec_equality(self):
self.assertTrue(LeafSpec() == LeafSpec())
self.assertTrue(TreeSpec(list, None, []) == TreeSpec(list, None, []))
self.assertTrue(TreeSpec(list, None, [LeafSpec()]) == TreeSpec(list, None, [LeafSpec()]))
self.assertFalse(TreeSpec(tuple, None, []) == TreeSpec(list, None, []))
self.assertTrue(TreeSpec(tuple, None, []) != TreeSpec(list, None, []))
self.assertTrue(
py_pytree.LeafSpec() == py_pytree.LeafSpec(),
)
self.assertTrue(
py_pytree.TreeSpec(list, None, []) == py_pytree.TreeSpec(list, None, []),
)
self.assertTrue(
py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()])
== py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
)
self.assertFalse(
py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []),
)
self.assertTrue(
py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []),
)
def test_flatten_unflatten_leaf(self):
def run_test_with_leaf(leaf):
values, treespec = tree_flatten(leaf)
values, treespec = py_pytree.tree_flatten(leaf)
self.assertEqual(values, [leaf])
self.assertEqual(treespec, LeafSpec())
self.assertEqual(treespec, py_pytree.LeafSpec())
unflattened = tree_unflatten(values, treespec)
unflattened = py_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, leaf)
run_test_with_leaf(1)
run_test_with_leaf(1.)
run_test_with_leaf(1.0)
run_test_with_leaf(None)
run_test_with_leaf(bool)
run_test_with_leaf(torch.randn(3, 3))
def test_flatten_unflatten_list(self):
def run_test(lst):
expected_spec = TreeSpec(list, None, [LeafSpec() for _ in lst])
values, treespec = tree_flatten(lst)
expected_spec = py_pytree.TreeSpec(
list, None, [py_pytree.LeafSpec() for _ in lst]
)
values, treespec = py_pytree.tree_flatten(lst)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, lst)
self.assertEqual(treespec, expected_spec)
unflattened = tree_unflatten(values, treespec)
unflattened = py_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, lst)
self.assertTrue(isinstance(unflattened, list))
run_test([])
run_test([1., 2])
run_test([torch.tensor([1., 2]), 2, 10, 9, 11])
run_test([1.0, 2])
run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])
def test_flatten_unflatten_tuple(self):
def run_test(tup):
expected_spec = TreeSpec(tuple, None, [LeafSpec() for _ in tup])
values, treespec = tree_flatten(tup)
expected_spec = py_pytree.TreeSpec(
tuple, None, [py_pytree.LeafSpec() for _ in tup]
)
values, treespec = py_pytree.tree_flatten(tup)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, list(tup))
self.assertEqual(treespec, expected_spec)
unflattened = tree_unflatten(values, treespec)
unflattened = py_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, tup)
self.assertTrue(isinstance(unflattened, tuple))
run_test(())
run_test((1.,))
run_test((1., 2))
run_test((torch.tensor([1., 2]), 2, 10, 9, 11))
run_test((1.0,))
run_test((1.0, 2))
run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))
def test_flatten_unflatten_odict(self):
def run_test(odict):
expected_spec = TreeSpec(
expected_spec = py_pytree.TreeSpec(
OrderedDict,
list(odict.keys()),
[LeafSpec() for _ in odict.values()])
values, treespec = tree_flatten(odict)
[py_pytree.LeafSpec() for _ in odict.values()],
)
values, treespec = py_pytree.tree_flatten(odict)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, list(odict.values()))
self.assertEqual(treespec, expected_spec)
unflattened = tree_unflatten(values, treespec)
unflattened = py_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, odict)
self.assertTrue(isinstance(unflattened, OrderedDict))
od = OrderedDict()
run_test(od)
od['b'] = 1
od['a'] = torch.tensor(3.14)
od["b"] = 1
od["a"] = torch.tensor(3.14)
run_test(od)
def test_flatten_unflatten_namedtuple(self):
Point = namedtuple('Point', ['x', 'y'])
Point = namedtuple("Point", ["x", "y"])
def run_test(tup):
expected_spec = TreeSpec(namedtuple, Point, [LeafSpec() for _ in tup])
values, treespec = tree_flatten(tup)
expected_spec = py_pytree.TreeSpec(
namedtuple, Point, [py_pytree.LeafSpec() for _ in tup]
)
values, treespec = py_pytree.tree_flatten(tup)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, list(tup))
self.assertEqual(treespec, expected_spec)
unflattened = tree_unflatten(values, treespec)
unflattened = py_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, tup)
self.assertTrue(isinstance(unflattened, Point))
run_test(Point(1., 2))
run_test(Point(torch.tensor(1.), 2))
run_test(Point(1.0, 2))
run_test(Point(torch.tensor(1.0), 2))
@parametrize("op", [
subtest(torch.max, name='max'),
subtest(torch.min, name='min'),
])
@parametrize(
"op",
[
subtest(torch.max, name="max"),
subtest(torch.min, name="min"),
],
)
def test_flatten_unflatten_return_type(self, op):
x = torch.randn(3, 3)
expected = op(x, dim=0)
values, spec = tree_flatten(expected)
values, spec = py_pytree.tree_flatten(expected)
# Check that values is actually List[Tensor] and not (ReturnType(...),)
for value in values:
self.assertTrue(isinstance(value, torch.Tensor))
result = tree_unflatten(values, spec)
result = py_pytree.tree_unflatten(values, spec)
self.assertEqual(type(result), type(expected))
self.assertEqual(result, expected)
def test_flatten_unflatten_dict(self):
def run_test(tup):
expected_spec = TreeSpec(dict, list(tup.keys()),
[LeafSpec() for _ in tup.values()])
values, treespec = tree_flatten(tup)
def run_test(dct):
expected_spec = py_pytree.TreeSpec(
dict, list(dct.keys()), [py_pytree.LeafSpec() for _ in dct.values()]
)
values, treespec = py_pytree.tree_flatten(dct)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, list(tup.values()))
self.assertEqual(values, list(dct.values()))
self.assertEqual(treespec, expected_spec)
unflattened = tree_unflatten(values, treespec)
self.assertEqual(unflattened, tup)
unflattened = py_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, dct)
self.assertTrue(isinstance(unflattened, dict))
run_test({})
run_test({'a': 1})
run_test({'abcdefg': torch.randn(2, 3)})
run_test({"a": 1})
run_test({"abcdefg": torch.randn(2, 3)})
run_test({1: torch.randn(2, 3)})
run_test({'a': 1, 'b': 2, 'c': torch.randn(2, 3)})
run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})
def test_flatten_unflatten_nested(self):
def run_test(pytree):
values, treespec = tree_flatten(pytree)
values, treespec = py_pytree.tree_flatten(pytree)
self.assertTrue(isinstance(values, list))
self.assertEqual(len(values), treespec.num_leaves)
# NB: python basic data structures (dict list tuple) all have
# contents equality defined on them, so the following works for them.
unflattened = tree_unflatten(values, treespec)
unflattened = py_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, pytree)
cases = [
[()],
([],),
{'a': ()},
{'a': 0, 'b': [{'c': 1}]},
{'a': 0, 'b': [1, {'c': 2}, torch.randn(3)], 'c': (torch.randn(2, 3), 1)},
{"a": ()},
{"a": 0, "b": [{"c": 1}]},
{"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
]
for case in cases:
run_test(case)
def test_treemap(self):
def run_test(pytree):
def f(x):
return x * 3
sm1 = sum(map(tree_flatten(pytree)[0], f))
sm2 = tree_flatten(tree_map(f, pytree))[0]
sm1 = sum(map(f, py_pytree.tree_flatten(pytree)[0]))
sm2 = sum(py_pytree.tree_flatten(py_pytree.tree_map(f, pytree))[0])
self.assertEqual(sm1, sm2)
def invf(x):
return x // 3
self.assertEqual(tree_flatten(tree_flatten(pytree, f), invf), pytree)
cases = [
[()],
([],),
{'a': ()},
{'a': 1, 'b': [{'c': 2}]},
{'a': 0, 'b': [2, {'c': 3}, 4], 'c': (5, 6)},
]
for case in cases:
run_test(case)
self.assertEqual(
py_pytree.tree_map(invf, py_pytree.tree_map(f, pytree)),
pytree,
)
cases = [
[()],
([],),
{"a": ()},
{"a": 1, "b": [{"c": 2}]},
{"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
]
for case in cases:
run_test(case)
def test_tree_only(self):
self.assertEqual(tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"])
self.assertEqual(
py_pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
)
def test_tree_all_any(self):
self.assertTrue(tree_all(lambda x: x % 2, [1, 3]))
self.assertFalse(tree_all(lambda x: x % 2, [0, 1]))
self.assertTrue(tree_any(lambda x: x % 2, [0, 1]))
self.assertFalse(tree_any(lambda x: x % 2, [0, 2]))
self.assertTrue(tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
self.assertFalse(tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
self.assertTrue(tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
self.assertFalse(tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))
self.assertTrue(py_pytree.tree_all(lambda x: x % 2, [1, 3]))
self.assertFalse(py_pytree.tree_all(lambda x: x % 2, [0, 1]))
self.assertTrue(py_pytree.tree_any(lambda x: x % 2, [0, 1]))
self.assertFalse(py_pytree.tree_any(lambda x: x % 2, [0, 2]))
self.assertTrue(py_pytree.tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
self.assertFalse(py_pytree.tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
self.assertTrue(py_pytree.tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
self.assertFalse(py_pytree.tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
def test_treespec_repr(self):
# Check that it looks sane
pytree = (0, [0, 0, [0]])
_, spec = tree_flatten(pytree)
self.assertEqual(repr(spec), ("TreeSpec(tuple, None, [*,\n"
" TreeSpec(list, None, [*,\n"
" *,\n"
" TreeSpec(list, None, [*])])])"))
_, spec = py_pytree.tree_flatten(pytree)
self.assertEqual(
repr(spec),
(
"TreeSpec(tuple, None, [*,\n"
" TreeSpec(list, None, [*,\n"
" *,\n"
" TreeSpec(list, None, [*])])])"
),
)
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
def test_treespec_repr_dynamo(self):
# Check that it looks sane
pytree = (0, [0, 0, [0]])
_, spec = tree_flatten(pytree)
self.assertExpectedInline(repr(spec),
"""\
_, spec = py_pytree.tree_flatten(pytree)
self.assertExpectedInline(
repr(spec),
"""\
TreeSpec(TupleVariable, None, [*,
TreeSpec(ListVariable, None, [*,
*,
TreeSpec(ListVariable, None, [*])])])""")
TreeSpec(ListVariable, None, [*])])])""",
)
def test_broadcast_to_and_flatten(self):
cases = [
(1, (), []),
# Same (flat) structures
((1,), (0,), [1]),
([1], [0], [1]),
((1, 2, 3), (0, 0, 0), [1, 2, 3]),
({'a': 1, 'b': 2}, {'a': 0, 'b': 0}, [1, 2]),
({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
# Mismatched (flat) structures
([1], (0,), None),
([1], (0,), None),
((1,), [0], None),
((1, 2, 3), (0, 0), None),
({'a': 1, 'b': 2}, {'a': 0}, None),
({'a': 1, 'b': 2}, {'a': 0, 'c': 0}, None),
({'a': 1, 'b': 2}, {'a': 0, 'b': 0, 'c': 0}, None),
({"a": 1, "b": 2}, {"a": 0}, None),
({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
# Same (nested) structures
((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
# Mismatched (nested) structures
((1, [2, 3]), (0, (0, 0)), None),
((1, [2, 3]), (0, [0, 0, 0]), None),
# Broadcasting single value
(1, (0, 0, 0), [1, 1, 1]),
(1, [0, 0, 0], [1, 1, 1]),
(1, {'a': 0, 'b': 0}, [1, 1]),
(1, {"a": 0, "b": 0}, [1, 1]),
(1, (0, [0, [0]], 0), [1, 1, 1, 1]),
(1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
# Broadcast multiple things
((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
(([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
]
for pytree, to_pytree, expected in cases:
_, to_spec = tree_flatten(to_pytree)
result = _broadcast_to_and_flatten(pytree, to_spec)
_, to_spec = py_pytree.tree_flatten(to_pytree)
result = py_pytree._broadcast_to_and_flatten(pytree, to_spec)
self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
@parametrize("spec", [
TreeSpec(list, None, []),
TreeSpec(tuple, None, []),
TreeSpec(dict, [], []),
TreeSpec(list, None, [LeafSpec()]),
TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
TreeSpec(tuple, None, [LeafSpec(), LeafSpec(), LeafSpec()]),
TreeSpec(dict, ['a', 'b', 'c'], [LeafSpec(), LeafSpec(), LeafSpec()]),
TreeSpec(OrderedDict, ['a', 'b', 'c'], [
TreeSpec(
@parametrize(
"spec",
[
py_pytree.TreeSpec(list, None, []),
py_pytree.TreeSpec(tuple, None, []),
py_pytree.TreeSpec(dict, [], []),
py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
py_pytree.TreeSpec(
list, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
),
py_pytree.TreeSpec(
tuple,
None,
[LeafSpec(), LeafSpec()]
[py_pytree.LeafSpec(), py_pytree.LeafSpec(), py_pytree.LeafSpec()],
),
LeafSpec(),
TreeSpec(
py_pytree.TreeSpec(
dict,
['a', 'b', 'c'],
[LeafSpec(), LeafSpec(), LeafSpec()]
["a", "b", "c"],
[py_pytree.LeafSpec(), py_pytree.LeafSpec(), py_pytree.LeafSpec()],
),
]),
TreeSpec(list, None, [
TreeSpec(tuple, None, [
LeafSpec(),
LeafSpec(),
TreeSpec(list, None, [
LeafSpec(),
LeafSpec(),
]),
]),
]),
],)
py_pytree.TreeSpec(
OrderedDict,
["a", "b", "c"],
[
py_pytree.TreeSpec(
tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
),
py_pytree.LeafSpec(),
py_pytree.TreeSpec(
dict,
["a", "b", "c"],
[
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
],
),
],
),
py_pytree.TreeSpec(
list,
None,
[
py_pytree.TreeSpec(
tuple,
None,
[
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
py_pytree.TreeSpec(
list,
None,
[
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
],
),
],
),
],
),
],
)
def test_pytree_serialize(self, spec):
serialized_spec = treespec_dumps(spec)
serialized_spec = py_pytree.treespec_dumps(spec)
self.assertTrue(isinstance(serialized_spec, str))
self.assertTrue(spec == treespec_loads(serialized_spec))
self.assertTrue(spec == py_pytree.treespec_loads(serialized_spec))
def test_pytree_serialize_namedtuple(self):
Point = namedtuple("Point", ["x", "y"])
spec = TreeSpec(namedtuple, Point, [LeafSpec(), LeafSpec()])
spec = py_pytree.TreeSpec(
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
roundtrip_spec = treespec_loads(treespec_dumps(spec))
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
# The context in the namedtuple is different now because we recreated
# the namedtuple type.
self.assertEqual(spec.context._fields, roundtrip_spec.context._fields)
@ -327,17 +394,19 @@ TreeSpec(TupleVariable, None, [*,
self.x = x
self.y = y
_register_pytree_node(
py_pytree._register_pytree_node(
DummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: Dummy(*xs),
lambda xs, _: DummyType(*xs),
to_dumpable_context=lambda context: "moo",
from_dumpable_context=lambda dumpable_context: None,
)
spec = TreeSpec(DummyType, None, [LeafSpec(), LeafSpec()])
serialized_spec = treespec_dumps(spec, 1)
spec = py_pytree.TreeSpec(
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
serialized_spec = py_pytree.treespec_dumps(spec, 1)
self.assertTrue("moo" in serialized_spec)
roundtrip_spec = treespec_loads(serialized_spec)
roundtrip_spec = py_pytree.treespec_loads(serialized_spec)
self.assertEqual(roundtrip_spec, spec)
def test_pytree_serialize_register_bad(self):
@ -346,11 +415,13 @@ TreeSpec(TupleVariable, None, [*,
self.x = x
self.y = y
with self.assertRaisesRegex(ValueError, "Both to_dumpable_context and from_dumpable_context"):
_register_pytree_node(
with self.assertRaisesRegex(
ValueError, "Both to_dumpable_context and from_dumpable_context"
):
py_pytree._register_pytree_node(
DummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: Dummy(*xs),
lambda xs, _: DummyType(*xs),
to_dumpable_context=lambda context: "moo",
)
@ -360,55 +431,67 @@ TreeSpec(TupleVariable, None, [*,
self.x = x
self.y = y
_register_pytree_node(
py_pytree._register_pytree_node(
DummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: Dummy(*xs),
lambda xs, _: DummyType(*xs),
to_dumpable_context=lambda context: DummyType,
from_dumpable_context=lambda dumpable_context: None,
)
spec = TreeSpec(DummyType, None, [LeafSpec(), LeafSpec()])
spec = py_pytree.TreeSpec(
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
with self.assertRaisesRegex(TypeError, "Object of type type is not JSON serializable"):
treespec_dumps(spec)
with self.assertRaisesRegex(
TypeError, "Object of type type is not JSON serializable"
):
py_pytree.treespec_dumps(spec)
def test_pytree_serialize_bad_input(self):
with self.assertRaises(AttributeError):
treespec_dumps("random_blurb")
py_pytree.treespec_dumps("random_blurb")
def test_pytree_serialize_bad_protocol(self):
import json
Point = namedtuple("Point", ["x", "y"])
spec = TreeSpec(namedtuple, Point, [LeafSpec(), LeafSpec()])
spec = py_pytree.TreeSpec(
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
with self.assertRaisesRegex(ValueError, "Unknown protocol"):
treespec_dumps(spec, -1)
py_pytree.treespec_dumps(spec, -1)
serialized_spec = treespec_dumps(spec)
serialized_spec = py_pytree.treespec_dumps(spec)
protocol, data = json.loads(serialized_spec)
bad_protocol_serialized_spec = json.dumps((-1, data))
with self.assertRaisesRegex(ValueError, "Unknown protocol"):
treespec_loads(bad_protocol_serialized_spec)
py_pytree.treespec_loads(bad_protocol_serialized_spec)
def test_saved_serialized(self):
complicated_spec = TreeSpec(OrderedDict, [1, 2, 3], [
TreeSpec(
tuple,
None,
[LeafSpec(), LeafSpec()]
),
LeafSpec(),
TreeSpec(
dict,
[4, 5, 6],
[LeafSpec(), LeafSpec(), LeafSpec()]
),
])
complicated_spec = py_pytree.TreeSpec(
OrderedDict,
[1, 2, 3],
[
py_pytree.TreeSpec(
tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
),
py_pytree.LeafSpec(),
py_pytree.TreeSpec(
dict,
[4, 5, 6],
[
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
py_pytree.LeafSpec(),
],
),
],
)
serialized_spec = treespec_dumps(complicated_spec)
serialized_spec = py_pytree.treespec_dumps(complicated_spec)
saved_spec = (
'[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", '
'"children_spec": [{"type": "builtins.tuple", "context": "null", '
@ -421,10 +504,321 @@ TreeSpec(TupleVariable, None, [*,
'[]}, {"type": null, "context": null, "children_spec": []}]}]}]'
)
self.assertEqual(serialized_spec, saved_spec)
self.assertEqual(complicated_spec, treespec_loads(saved_spec))
self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec))
class TestCxxPytree(TestCase):
def test_treespec_equality(self):
self.assertTrue(cxx_pytree.LeafSpec() == cxx_pytree.LeafSpec())
def test_flatten_unflatten_leaf(self):
def run_test_with_leaf(leaf):
values, treespec = cxx_pytree.tree_flatten(leaf)
self.assertEqual(values, [leaf])
self.assertEqual(treespec, cxx_pytree.LeafSpec())
unflattened = cxx_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, leaf)
run_test_with_leaf(1)
run_test_with_leaf(1.0)
run_test_with_leaf(None)
run_test_with_leaf(bool)
run_test_with_leaf(torch.randn(3, 3))
def test_flatten_unflatten_list(self):
def run_test(lst):
expected_spec = cxx_pytree.tree_structure([0] * len(lst))
values, treespec = cxx_pytree.tree_flatten(lst)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, lst)
self.assertEqual(treespec, expected_spec)
unflattened = cxx_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, lst)
self.assertTrue(isinstance(unflattened, list))
run_test([])
run_test([1.0, 2])
run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])
def test_flatten_unflatten_tuple(self):
def run_test(tup):
expected_spec = cxx_pytree.tree_structure((0,) * len(tup))
values, treespec = cxx_pytree.tree_flatten(tup)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, list(tup))
self.assertEqual(treespec, expected_spec)
unflattened = cxx_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, tup)
self.assertTrue(isinstance(unflattened, tuple))
run_test(())
run_test((1.0,))
run_test((1.0, 2))
run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))
def test_flatten_unflatten_odict(self):
def run_test(odict):
expected_spec = cxx_pytree.tree_structure(OrderedDict.fromkeys(odict, 0))
values, treespec = cxx_pytree.tree_flatten(odict)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, list(odict.values()))
self.assertEqual(treespec, expected_spec)
unflattened = cxx_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, odict)
self.assertTrue(isinstance(unflattened, OrderedDict))
od = OrderedDict()
run_test(od)
od["b"] = 1
od["a"] = torch.tensor(3.14)
run_test(od)
def test_flatten_unflatten_namedtuple(self):
Point = namedtuple("Point", ["x", "y"])
def run_test(tup):
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
values, treespec = cxx_pytree.tree_flatten(tup)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, list(tup))
self.assertEqual(treespec, expected_spec)
unflattened = cxx_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, tup)
self.assertTrue(isinstance(unflattened, Point))
run_test(Point(1.0, 2))
run_test(Point(torch.tensor(1.0), 2))
@parametrize(
"op",
[
subtest(torch.max, name="max"),
subtest(torch.min, name="min"),
],
)
def test_flatten_unflatten_return_type(self, op):
x = torch.randn(3, 3)
expected = op(x, dim=0)
values, spec = cxx_pytree.tree_flatten(expected)
# Check that values is actually List[Tensor] and not (ReturnType(...),)
for value in values:
self.assertTrue(isinstance(value, torch.Tensor))
result = cxx_pytree.tree_unflatten(values, spec)
self.assertEqual(type(result), type(expected))
self.assertEqual(result, expected)
def test_flatten_unflatten_dict(self):
def run_test(dct):
expected_spec = cxx_pytree.tree_structure(dict.fromkeys(dct, 0))
values, treespec = cxx_pytree.tree_flatten(dct)
self.assertTrue(isinstance(values, list))
self.assertEqual(values, list(dct.values()))
self.assertEqual(treespec, expected_spec)
unflattened = cxx_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, dct)
self.assertTrue(isinstance(unflattened, dict))
run_test({})
run_test({"a": 1})
run_test({"abcdefg": torch.randn(2, 3)})
run_test({1: torch.randn(2, 3)})
run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})
def test_flatten_unflatten_nested(self):
def run_test(pytree):
values, treespec = cxx_pytree.tree_flatten(pytree)
self.assertTrue(isinstance(values, list))
self.assertEqual(len(values), treespec.num_leaves)
# NB: python basic data structures (dict list tuple) all have
# contents equality defined on them, so the following works for them.
unflattened = cxx_pytree.tree_unflatten(values, treespec)
self.assertEqual(unflattened, pytree)
cases = [
[()],
([],),
{"a": ()},
{"a": 0, "b": [{"c": 1}]},
{"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
]
for case in cases:
run_test(case)
def test_treemap(self):
def run_test(pytree):
def f(x):
return x * 3
sm1 = sum(map(f, cxx_pytree.tree_flatten(pytree)[0]))
sm2 = sum(cxx_pytree.tree_flatten(cxx_pytree.tree_map(f, pytree))[0])
self.assertEqual(sm1, sm2)
def invf(x):
return x // 3
self.assertEqual(
cxx_pytree.tree_map(invf, cxx_pytree.tree_map(f, pytree)),
pytree,
)
cases = [
[()],
([],),
{"a": ()},
{"a": 1, "b": [{"c": 2}]},
{"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
]
for case in cases:
run_test(case)
def test_tree_only(self):
self.assertEqual(
cxx_pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
)
def test_tree_all_any(self):
self.assertTrue(cxx_pytree.tree_all(lambda x: x % 2, [1, 3]))
self.assertFalse(cxx_pytree.tree_all(lambda x: x % 2, [0, 1]))
self.assertTrue(cxx_pytree.tree_any(lambda x: x % 2, [0, 1]))
self.assertFalse(cxx_pytree.tree_any(lambda x: x % 2, [0, 2]))
self.assertTrue(cxx_pytree.tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
self.assertFalse(cxx_pytree.tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
self.assertTrue(cxx_pytree.tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
self.assertFalse(cxx_pytree.tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
def test_treespec_repr(self):
# Check that it looks sane
pytree = (0, [0, 0, [0]])
_, spec = cxx_pytree.tree_flatten(pytree)
self.assertEqual(
repr(spec),
("PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)"),
)
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
def test_treespec_repr_dynamo(self):
# Check that it looks sane
pytree = (0, [0, 0, [0]])
_, spec = cxx_pytree.tree_flatten(pytree)
self.assertExpectedInline(
repr(spec),
"PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)",
)
def test_broadcast_to_and_flatten(self):
cases = [
(1, (), []),
# Same (flat) structures
((1,), (0,), [1]),
([1], [0], [1]),
((1, 2, 3), (0, 0, 0), [1, 2, 3]),
({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
# Mismatched (flat) structures
([1], (0,), None),
([1], (0,), None),
((1,), [0], None),
((1, 2, 3), (0, 0), None),
({"a": 1, "b": 2}, {"a": 0}, None),
({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
# Same (nested) structures
((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
# Mismatched (nested) structures
((1, [2, 3]), (0, (0, 0)), None),
((1, [2, 3]), (0, [0, 0, 0]), None),
# Broadcasting single value
(1, (0, 0, 0), [1, 1, 1]),
(1, [0, 0, 0], [1, 1, 1]),
(1, {"a": 0, "b": 0}, [1, 1]),
(1, (0, [0, [0]], 0), [1, 1, 1, 1]),
(1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
# Broadcast multiple things
((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
(([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
]
for pytree, to_pytree, expected in cases:
_, to_spec = cxx_pytree.tree_flatten(to_pytree)
result = cxx_pytree._broadcast_to_and_flatten(pytree, to_spec)
self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
@parametrize(
"spec",
[
cxx_pytree.tree_structure([]),
cxx_pytree.tree_structure(()),
cxx_pytree.tree_structure({}),
cxx_pytree.tree_structure([0]),
cxx_pytree.tree_structure([0, 1]),
cxx_pytree.tree_structure((0, 1, 2)),
cxx_pytree.tree_structure({"a": 0, "b": 1, "c": 2}),
cxx_pytree.tree_structure(
OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
),
cxx_pytree.tree_structure([(0, 1, [2, 3])]),
],
)
def test_pytree_serialize(self, spec):
serialized_spec = cxx_pytree.treespec_dumps(spec)
self.assertTrue(isinstance(serialized_spec, bytes))
self.assertTrue(spec == cxx_pytree.treespec_loads(serialized_spec))
def test_pytree_serialize_namedtuple(self):
spec = cxx_pytree.tree_structure(GlobalPoint(0, 1))
roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
self.assertEqual(roundtrip_spec, spec)
LocalPoint = namedtuple("LocalPoint", ["x", "y"])
spec = cxx_pytree.tree_structure(LocalPoint(0, 1))
with self.assertRaises(pickle.PicklingError):
cxx_pytree.treespec_dumps(spec)
def test_pytree_custom_type_serialize(self):
cxx_pytree.register_pytree_node(
GlobalDummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: GlobalDummyType(*xs),
)
spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1))
serialized_spec = cxx_pytree.treespec_dumps(spec)
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
self.assertEqual(roundtrip_spec, spec)
class LocalDummyType:
def __init__(self, x, y):
self.x = x
self.y = y
cxx_pytree.register_pytree_node(
LocalDummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: LocalDummyType(*xs),
)
spec = cxx_pytree.tree_structure(LocalDummyType(0, 1))
with self.assertRaises(AttributeError):
serialized_spec = cxx_pytree.treespec_dumps(spec)
def test_pytree_serialize_bad_input(self):
with self.assertRaises(TypeError):
cxx_pytree.treespec_dumps("random_blurb")
instantiate_parametrized_tests(TestPytree)
instantiate_parametrized_tests(TestCxxPytree)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()

View File

@ -263,7 +263,7 @@ def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
This is the inverse operation of `tree_flatten`.
"""
if not isinstance(spec, TreeSpec):
raise ValueError(
raise TypeError(
f'tree_unflatten(values, spec): Expected `spec` to be instance of '
f'TreeSpec but got item of type {type(spec)}.')
if len(values) != spec.num_leaves:

View File

@ -109,6 +109,7 @@ def get_conda_packages(run_lambda):
"mkl",
"magma",
"triton",
"optree",
}
)
)
@ -389,6 +390,7 @@ def get_pip_packages(run_lambda):
"mypy",
"flake8",
"triton",
"optree",
}
)
)

767
torch/utils/pytree.py Normal file
View File

@ -0,0 +1,767 @@
"""
Contains utility functions for working with nested python data structures.
A *pytree* is Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.
pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.
"""
import functools
import pickle
from typing import (
Any,
Callable,
Iterable,
List,
Optional,
overload,
Tuple,
Type,
TypeVar,
Union,
)
import optree
from optree import PyTreeSpec # direct import for type annotations
__all__ = [
"PyTree",
"PyTreeSpec",
"register_pytree_node",
"tree_flatten",
"tree_unflatten",
"tree_leaves",
"tree_structure",
"tree_map",
"tree_map_",
"tree_map_only",
"tree_map_only_",
"tree_all",
"tree_any",
"tree_all_only",
"tree_any_only",
"broadcast_prefix",
"_broadcast_to_and_flatten",
"treespec_dumps",
"treespec_loads",
]
T = TypeVar("T")
S = TypeVar("S")
R = TypeVar("R")
Context = Optional[Any]
PyTree = Any
TreeSpec = PyTreeSpec
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
UnflattenFunc = Callable[[Iterable, Context], PyTree]
OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@functools.wraps(func)
def wrapped(*args: Any, **kwargs: Any) -> Any:
return func(*reversed(args), **kwargs)
return wrapped
def register_pytree_node(
cls: Type[Any],
flatten_func: FlattenFunc,
unflatten_func: UnflattenFunc,
namespace: str = "torch",
) -> None:
"""Extend the set of types that are considered internal nodes in pytrees.
The ``namespace`` argument is used to avoid collisions that occur when different libraries
register the same Python type with different behaviors. It is recommended to add a unique prefix
to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
the same class in different namespaces for different use cases.
.. warning::
For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
used to isolate the behavior of flattening and unflattening a pytree node type. This is to
prevent accidental collisions between different libraries that may register the same type.
Args:
cls (type): A Python type to treat as an internal pytree node.
flatten_fn (callable): A function to be used during flattening, taking an instance of ``cls``
and returning a triple or optionally a pair, with (1) an iterable for the children to be
flattened recursively, and (2) some hashable auxiliary data to be stored in the treespec
and to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree
path entries to the corresponding children. If the entries are not provided or given by
:data:`None`, then `range(len(children))` will be used.
unflatten_fn (callable): A function taking two arguments: the auxiliary data that was returned
by ``flatten_func`` and stored in the treespec, and the unflattened children. The function
should return an instance of ``cls``.
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
type registry. This is used to isolate the registry from other modules that might register
a different custom behavior for the same type. (default: :const:`"torch"`)
Example::
>>> # Registry a Python type with lambda functions
>>> register_pytree_node(
... set,
... lambda s: (sorted(s), None, None),
... lambda children, _: set(children),
... namespace='set',
... )
>>> # Register a Python type into a namespace
>>> import torch
>>> register_pytree_node(
... torch.Tensor,
... flatten_func=lambda tensor: (
... (tensor.cpu().numpy(),),
... dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
... ),
... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata),
... namespace='torch2numpy',
... )
>>> # xdoctest: +SKIP
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
>>> tree
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
>>> # Flatten without specifying the namespace
>>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
>>> # Flatten with the namespace
>>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP
(
[array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
PyTreeSpec(
{
'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]),
'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*])
},
namespace='torch2numpy'
)
)
>>> # Register the same type with a different namespace for different behaviors
>>> def tensor2flatparam(tensor):
... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
...
>>> def flatparam2tensor(children, metadata):
... return children[0].reshape(metadata)
...
>>> register_pytree_node(
... torch.Tensor,
... flatten_func=tensor2flatparam,
... unflatten_func=flatparam2tensor,
... namespace='tensor2flatparam',
... )
>>> # Flatten with the new namespace
>>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP
(
[
Parameter containing: tensor([0., 0.], requires_grad=True),
Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
],
PyTreeSpec(
{
'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
},
namespace='tensor2flatparam'
)
)
"""
optree.register_pytree_node(
cls, flatten_func, _reverse_args(unflatten_func), namespace=namespace
)
def tree_flatten(
tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> Tuple[List[Any], PyTreeSpec]:
"""Flatten a pytree.
See also :func:`tree_unflatten`.
The flattening order (i.e., the order of elements in the output list) is deterministic,
corresponding to a left-to-right depth-first tree traversal.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_flatten(tree)
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
>>> tree_flatten(tree, none_is_leaf=False)
([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}))
>>> tree_flatten(1)
([1], PyTreeSpec(*, NoneIsLeaf))
>>> tree_flatten(None)
([None], PyTreeSpec(*, NoneIsLeaf))
>>> tree_flatten(None, none_is_leaf=False)
([], PyTreeSpec(None))
For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
if you want to keep the keys in the insertion order.
>>> from collections import OrderedDict
>>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
>>> tree_flatten(tree)
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf))
>>> tree_flatten(tree, none_is_leaf=False)
([2, 3, 4, 1, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)])))
Args:
tree (pytree): A pytree to flatten.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list. (default: :data:`True`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`"torch"`)
Returns:
A pair ``(leaves, treespec)`` where the first element is a list of leaf values and the
second element is a treespec representing the structure of the pytree.
"""
return optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace)
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
"""Reconstruct a pytree from the treespec and the leaves.
The inverse of :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> leaves, treespec = tree_flatten(tree)
>>> tree == tree_unflatten(leaves, treespec)
True
Args:
leaves (iterable): The list of leaves to use for reconstruction. The list must match the
number of leaves of the treespec.
treespec (PyTreeSpec): The treespec to reconstruct.
Returns:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
if not isinstance(treespec, PyTreeSpec):
raise TypeError(
f"tree_unflatten(values, spec): Expected `spec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return optree.tree_unflatten(treespec, leaves)
def tree_leaves(
tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> List[Any]:
"""Get the leaves of a pytree.
See also :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_leaves(tree)
[1, 2, 3, 4, None, 5]
>>> tree_leaves(tree, none_is_leaf=False)
[1, 2, 3, 4, 5]
>>> tree_leaves(1)
[1]
>>> tree_leaves(None)
[None]
>>> tree_leaves(None, none_is_leaf=False)
[]
Args:
tree (pytree): A pytree to flatten.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list. (default: :data:`True`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`"torch"`)
Returns:
A list of leaf values.
"""
return optree.tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace)
def tree_structure(
tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> PyTreeSpec:
"""Get the treespec for a pytree.
See also :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_structure(tree)
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
>>> tree_structure(tree, none_is_leaf=False)
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
>>> tree_structure(1)
PyTreeSpec(*, NoneIsLeaf)
>>> tree_structure(None)
PyTreeSpec(*, NoneIsLeaf)
>>> tree_structure(None, none_is_leaf=False)
PyTreeSpec(None)
Args:
tree (pytree): A pytree to flatten.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list. (default: :data:`True`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`"torch"`)
Returns:
A treespec object representing the structure of the pytree.
"""
return optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace)
def tree_map(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> PyTree:
"""Map a multi-input function over pytree args to produce a new pytree.
See also :func:`tree_map_`.
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
{'x': 8, 'y': (43, 65)}
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
{'x': False, 'y': (False, False), 'z': True}
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=False)
{'x': 8, 'y': (43, 65), 'z': None}
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=False)
{'x': False, 'y': (False, False), 'z': None}
If multiple inputs are given, the structure of the tree is taken from the first input;
subsequent inputs need only have ``tree`` as a prefix:
>>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
[[5, 7, 9], [6, 1, 2]]
Args:
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
corresponding leaves of the pytrees.
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
argument to function ``func``.
rests (tuple of pytrees): A tuple of pytrees, each of which has the same structure as
``tree`` or has ``tree`` as a prefix.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list. (default: :data:`True`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`"torch"`)
Returns:
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
is the tuple of values at corresponding nodes in ``rests``.
"""
return optree.tree_map(
func, tree, *rests, none_is_leaf=none_is_leaf, namespace=namespace
)
def tree_map_(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> PyTree:
"""Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
See also :func:`tree_map`.
Args:
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
corresponding leaves of the pytrees.
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
argument to function ``func``.
rests (tuple of pytrees): A tuple of pytrees, each of which has the same structure as
``tree`` or has ``tree`` as a prefix.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list. (default: :data:`True`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`"torch"`)
Returns:
The original ``tree`` with the value at each leaf is given by the side-effect of function
``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``.
"""
return optree.tree_map_(
func, tree, *rests, none_is_leaf=none_is_leaf, namespace=namespace
)
Type2 = Tuple[Type[T], Type[S]]
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]
Fn2 = Callable[[Union[T, S]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(__type_or_types: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
...
@overload
def map_only(__type_or_types: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
...
# This specialization is needed for the implementations below that call
@overload
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
...
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
"""
Suppose you are writing a tree_map over tensors, leaving everything
else unchanged. Ordinarily you would have to write:
def go(t):
if isinstance(t, Tensor):
return ...
else:
return t
With this function, you only need to write:
@map_only(Tensor)
def go(t):
return ...
You can also directly use 'tree_map_only'
"""
def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
@functools.wraps(func)
def wrapped(x: T) -> Any:
if isinstance(x, __type_or_types):
return func(x)
return x
return wrapped
return wrapper
@overload
def tree_map_only(
__type_or_types: Type[T],
func: Fn[T, Any],
tree: PyTree,
*rests: PyTree,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> PyTree:
...
@overload
def tree_map_only(
__type_or_types: Type2[T, S],
func: Fn2[T, S, Any],
tree: PyTree,
*rests: PyTree,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> PyTree:
...
def tree_map_only(
__type_or_types: TypeAny,
func: FnAny[Any],
tree: PyTree,
*rests: PyTree,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> PyTree:
return tree_map(
map_only(__type_or_types)(func),
tree,
*rests,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
@overload
def tree_map_only_(
__type_or_types: Type[T],
func: Fn[T, Any],
tree: PyTree,
*rests: PyTree,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> PyTree:
...
@overload
def tree_map_only_(
__type_or_types: Type2[T, S],
func: Fn2[T, S, Any],
tree: PyTree,
*rests: PyTree,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> PyTree:
...
def tree_map_only_(
__type_or_types: TypeAny,
func: FnAny[Any],
tree: PyTree,
*rests: PyTree,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> PyTree:
return tree_map_(
map_only(__type_or_types)(func),
tree,
*rests,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
def tree_all(
pred: Callable[[Any], bool],
tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> bool:
flat_args = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace)
return all(map(pred, flat_args))
def tree_any(
pred: Callable[[Any], bool],
tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> bool:
flat_args = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace)
return any(map(pred, flat_args))
@overload
def tree_all_only(
__type_or_types: Type[T],
pred: Fn[T, bool],
tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> bool:
...
@overload
def tree_all_only(
__type_or_types: Type2[T, S],
pred: Fn2[T, S, bool],
tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> bool:
...
def tree_all_only(
__type_or_types: TypeAny,
pred: FnAny[bool],
tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> bool:
flat_args = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace)
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
@overload
def tree_any_only(
__type_or_types: Type[T],
pred: Fn[T, bool],
tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> bool:
...
@overload
def tree_any_only(
__type_or_types: Type2[T, S],
pred: Fn2[T, S, bool],
tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> bool:
...
def tree_any_only(
__type_or_types: TypeAny,
pred: FnAny[bool],
tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> bool:
flat_args = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace)
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
def broadcast_prefix(
prefix_tree: PyTree,
full_tree: PyTree,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> List[Any]:
"""Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.
If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be
constructed by replacing the leaves of ``prefix_tree`` with appropriate **subtrees**.
This function returns a list of leaves with the same size as ``full_tree``. The leaves are
replicated from ``prefix_tree``. The number of replicas is determined by the corresponding
subtree in ``full_tree``.
>>> broadcast_prefix(1, [1, 2, 3])
[1, 1, 1]
>>> broadcast_prefix([1, 2, 3], [1, 2, 3])
[1, 2, 3]
>>> broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
Traceback (most recent call last):
...
ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
>>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
[1, 2, 3, 3]
>>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
[1, 2, 3, 3, 3, 3]
>>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=False)
[1, 2, 3, 3, 3]
Args:
prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``.
full_tree (pytree): A pytree with the same structure as a suffix of ``prefix_tree``.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
flattening should traverse the current object.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list. (default: :data:`True`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`"torch"`)
Returns:
A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
"""
return optree.broadcast_prefix(
prefix_tree, full_tree, none_is_leaf=none_is_leaf, namespace=namespace
)
# Broadcasts a pytree to the provided TreeSpec and returns the flattened
# values. If this is not possible, then this function returns None.
#
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
# would return [0, 0]. This is useful for part of the vmap implementation:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and we use
# _broadcast_to_and_flatten to check this.
def _broadcast_to_and_flatten(
tree: PyTree,
treespec: PyTreeSpec,
*,
none_is_leaf: bool = True,
namespace: str = "torch",
) -> Optional[List[Any]]:
assert isinstance(treespec, PyTreeSpec)
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
return broadcast_prefix(
tree, full_tree, none_is_leaf=none_is_leaf, namespace=namespace
)
except ValueError:
return None
def treespec_dumps(treespec: PyTreeSpec) -> bytes:
"""Serialize a treespec to bytes."""
if not isinstance(treespec, PyTreeSpec):
raise TypeError(
f"treespec_dumps(spec): Expected `spec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return pickle.dumps(treespec)
def treespec_loads(serialized: bytes) -> PyTreeSpec:
"""Deserialize a treespec from bytes."""
treespec = pickle.loads(serialized)
if not isinstance(treespec, PyTreeSpec):
raise TypeError(
f"treespec_loads(serialized): Expected to return an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return treespec
class PyTreeLeafSpecMeta(type(optree.PyTreeSpec)): # type: ignore[misc]
def __instancecheck__(self, instance: object) -> bool:
return isinstance(instance, optree.PyTreeSpec) and instance.is_leaf()
class PyTreeLeafSpec(optree.PyTreeSpec, metaclass=PyTreeLeafSpecMeta):
def __new__(cls, none_is_leaf: bool = True) -> "PyTreeLeafSpec":
return optree.treespec_leaf(none_is_leaf=none_is_leaf) # type: ignore[return-value]
LeafSpec = PyTreeLeafSpec