mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytree] Extract reusable generic tests for pytree (#110395)
Part of #109684 - #109684 Changes: - Add new functions `tree_structure`, `tree_leaves`, `tree_map_` and `tree_map_only_` to Python pytree. - Extract reusable tests for pytree to `TestGenericPytree`. - Change `treespec_dumps` and `treespec_loads` in C++ pytree to call Python pytree and use JSON string as serialization type. - Rename `torch.utils.pytree` -> `torch.utils._cxx_pytree`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110395 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
37afa0c349
commit
449271f3f1
@ -223,7 +223,7 @@ include_patterns = [
|
||||
'tools/**/*.py',
|
||||
'torchgen/**/*.py',
|
||||
'torch/utils/_pytree.py',
|
||||
'torch/utils/pytree.py',
|
||||
'torch/utils/_cxx_pytree.py',
|
||||
'torch/utils/benchmark/utils/common.py',
|
||||
'torch/utils/benchmark/utils/timer.py',
|
||||
'torch/utils/benchmark/utils/valgrind_wrapper/**/*.py',
|
||||
|
@ -43,7 +43,7 @@ files =
|
||||
tools,
|
||||
torch/profiler/_memory_profiler.py,
|
||||
torch/utils/_pytree.py,
|
||||
torch/utils/pytree.py,
|
||||
torch/utils/_cxx_pytree.py,
|
||||
torch/utils/benchmark/utils/common.py,
|
||||
torch/utils/benchmark/utils/timer.py,
|
||||
torch/utils/benchmark/utils/valgrind_wrapper
|
||||
|
@ -1,12 +1,11 @@
|
||||
# Owner(s): ["module: pytree"]
|
||||
|
||||
import pickle
|
||||
import unittest
|
||||
from collections import namedtuple, OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.utils._cxx_pytree as cxx_pytree
|
||||
import torch.utils._pytree as py_pytree
|
||||
import torch.utils.pytree as cxx_pytree
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -26,32 +25,21 @@ class GlobalDummyType:
|
||||
self.y = y
|
||||
|
||||
|
||||
class TestPytree(TestCase):
|
||||
def test_treespec_equality(self):
|
||||
self.assertTrue(
|
||||
py_pytree.LeafSpec() == py_pytree.LeafSpec(),
|
||||
)
|
||||
self.assertTrue(
|
||||
py_pytree.TreeSpec(list, None, []) == py_pytree.TreeSpec(list, None, []),
|
||||
)
|
||||
self.assertTrue(
|
||||
py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()])
|
||||
== py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
|
||||
)
|
||||
self.assertFalse(
|
||||
py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []),
|
||||
)
|
||||
self.assertTrue(
|
||||
py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []),
|
||||
)
|
||||
|
||||
def test_flatten_unflatten_leaf(self):
|
||||
class TestGenericPytree(TestCase):
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_flatten_unflatten_leaf(self, pytree_impl):
|
||||
def run_test_with_leaf(leaf):
|
||||
values, treespec = py_pytree.tree_flatten(leaf)
|
||||
values, treespec = pytree_impl.tree_flatten(leaf)
|
||||
self.assertEqual(values, [leaf])
|
||||
self.assertEqual(treespec, py_pytree.LeafSpec())
|
||||
self.assertEqual(treespec, pytree_impl.LeafSpec())
|
||||
|
||||
unflattened = py_pytree.tree_unflatten(values, treespec)
|
||||
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, leaf)
|
||||
|
||||
run_test_with_leaf(1)
|
||||
@ -60,17 +48,33 @@ class TestPytree(TestCase):
|
||||
run_test_with_leaf(bool)
|
||||
run_test_with_leaf(torch.randn(3, 3))
|
||||
|
||||
def test_flatten_unflatten_list(self):
|
||||
@parametrize(
|
||||
"pytree_impl,gen_expected_fn",
|
||||
[
|
||||
subtest(
|
||||
(
|
||||
py_pytree,
|
||||
lambda lst: py_pytree.TreeSpec(
|
||||
list, None, [py_pytree.LeafSpec() for _ in lst]
|
||||
),
|
||||
),
|
||||
name="py",
|
||||
),
|
||||
subtest(
|
||||
(cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))),
|
||||
name="cxx",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn):
|
||||
def run_test(lst):
|
||||
expected_spec = py_pytree.TreeSpec(
|
||||
list, None, [py_pytree.LeafSpec() for _ in lst]
|
||||
)
|
||||
values, treespec = py_pytree.tree_flatten(lst)
|
||||
expected_spec = gen_expected_fn(lst)
|
||||
values, treespec = pytree_impl.tree_flatten(lst)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(values, lst)
|
||||
self.assertEqual(treespec, expected_spec)
|
||||
|
||||
unflattened = py_pytree.tree_unflatten(values, treespec)
|
||||
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, lst)
|
||||
self.assertTrue(isinstance(unflattened, list))
|
||||
|
||||
@ -78,17 +82,33 @@ class TestPytree(TestCase):
|
||||
run_test([1.0, 2])
|
||||
run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])
|
||||
|
||||
def test_flatten_unflatten_tuple(self):
|
||||
@parametrize(
|
||||
"pytree_impl,gen_expected_fn",
|
||||
[
|
||||
subtest(
|
||||
(
|
||||
py_pytree,
|
||||
lambda tup: py_pytree.TreeSpec(
|
||||
tuple, None, [py_pytree.LeafSpec() for _ in tup]
|
||||
),
|
||||
),
|
||||
name="py",
|
||||
),
|
||||
subtest(
|
||||
(cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))),
|
||||
name="cxx",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn):
|
||||
def run_test(tup):
|
||||
expected_spec = py_pytree.TreeSpec(
|
||||
tuple, None, [py_pytree.LeafSpec() for _ in tup]
|
||||
)
|
||||
values, treespec = py_pytree.tree_flatten(tup)
|
||||
expected_spec = gen_expected_fn(tup)
|
||||
values, treespec = pytree_impl.tree_flatten(tup)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(values, list(tup))
|
||||
self.assertEqual(treespec, expected_spec)
|
||||
|
||||
unflattened = py_pytree.tree_unflatten(values, treespec)
|
||||
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, tup)
|
||||
self.assertTrue(isinstance(unflattened, tuple))
|
||||
|
||||
@ -97,19 +117,81 @@ class TestPytree(TestCase):
|
||||
run_test((1.0, 2))
|
||||
run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))
|
||||
|
||||
def test_flatten_unflatten_odict(self):
|
||||
@parametrize(
|
||||
"pytree_impl,gen_expected_fn",
|
||||
[
|
||||
subtest(
|
||||
(
|
||||
py_pytree,
|
||||
lambda dct: py_pytree.TreeSpec(
|
||||
dict,
|
||||
list(dct.keys()),
|
||||
[py_pytree.LeafSpec() for _ in dct.values()],
|
||||
),
|
||||
),
|
||||
name="py",
|
||||
),
|
||||
subtest(
|
||||
(
|
||||
cxx_pytree,
|
||||
lambda dct: cxx_pytree.tree_structure(dict.fromkeys(dct, 0)),
|
||||
),
|
||||
name="cxx",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn):
|
||||
def run_test(dct):
|
||||
expected_spec = gen_expected_fn(dct)
|
||||
values, treespec = pytree_impl.tree_flatten(dct)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(values, list(dct.values()))
|
||||
self.assertEqual(treespec, expected_spec)
|
||||
|
||||
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, dct)
|
||||
self.assertTrue(isinstance(unflattened, dict))
|
||||
|
||||
run_test({})
|
||||
run_test({"a": 1})
|
||||
run_test({"abcdefg": torch.randn(2, 3)})
|
||||
run_test({1: torch.randn(2, 3)})
|
||||
run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})
|
||||
|
||||
@parametrize(
|
||||
"pytree_impl,gen_expected_fn",
|
||||
[
|
||||
subtest(
|
||||
(
|
||||
py_pytree,
|
||||
lambda odict: py_pytree.TreeSpec(
|
||||
OrderedDict,
|
||||
list(odict.keys()),
|
||||
[py_pytree.LeafSpec() for _ in odict.values()],
|
||||
),
|
||||
),
|
||||
name="py",
|
||||
),
|
||||
subtest(
|
||||
(
|
||||
cxx_pytree,
|
||||
lambda odict: cxx_pytree.tree_structure(
|
||||
OrderedDict.fromkeys(odict, 0)
|
||||
),
|
||||
),
|
||||
name="cxx",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_flatten_unflatten_odict(self, pytree_impl, gen_expected_fn):
|
||||
def run_test(odict):
|
||||
expected_spec = py_pytree.TreeSpec(
|
||||
OrderedDict,
|
||||
list(odict.keys()),
|
||||
[py_pytree.LeafSpec() for _ in odict.values()],
|
||||
)
|
||||
values, treespec = py_pytree.tree_flatten(odict)
|
||||
expected_spec = gen_expected_fn(odict)
|
||||
values, treespec = pytree_impl.tree_flatten(odict)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(values, list(odict.values()))
|
||||
self.assertEqual(treespec, expected_spec)
|
||||
|
||||
unflattened = py_pytree.tree_unflatten(values, treespec)
|
||||
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, odict)
|
||||
self.assertTrue(isinstance(unflattened, OrderedDict))
|
||||
|
||||
@ -120,19 +202,29 @@ class TestPytree(TestCase):
|
||||
od["a"] = torch.tensor(3.14)
|
||||
run_test(od)
|
||||
|
||||
def test_flatten_unflatten_namedtuple(self):
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_flatten_unflatten_namedtuple(self, pytree_impl):
|
||||
Point = namedtuple("Point", ["x", "y"])
|
||||
|
||||
def run_test(tup):
|
||||
expected_spec = py_pytree.TreeSpec(
|
||||
namedtuple, Point, [py_pytree.LeafSpec() for _ in tup]
|
||||
)
|
||||
values, treespec = py_pytree.tree_flatten(tup)
|
||||
if pytree_impl is py_pytree:
|
||||
expected_spec = py_pytree.TreeSpec(
|
||||
namedtuple, Point, [py_pytree.LeafSpec() for _ in tup]
|
||||
)
|
||||
else:
|
||||
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
|
||||
values, treespec = pytree_impl.tree_flatten(tup)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(values, list(tup))
|
||||
self.assertEqual(treespec, expected_spec)
|
||||
|
||||
unflattened = py_pytree.tree_unflatten(values, treespec)
|
||||
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, tup)
|
||||
self.assertTrue(isinstance(unflattened, Point))
|
||||
|
||||
@ -146,48 +238,42 @@ class TestPytree(TestCase):
|
||||
subtest(torch.min, name="min"),
|
||||
],
|
||||
)
|
||||
def test_flatten_unflatten_return_type(self, op):
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_flatten_unflatten_return_type(self, pytree_impl, op):
|
||||
x = torch.randn(3, 3)
|
||||
expected = op(x, dim=0)
|
||||
|
||||
values, spec = py_pytree.tree_flatten(expected)
|
||||
values, spec = pytree_impl.tree_flatten(expected)
|
||||
# Check that values is actually List[Tensor] and not (ReturnType(...),)
|
||||
for value in values:
|
||||
self.assertTrue(isinstance(value, torch.Tensor))
|
||||
result = py_pytree.tree_unflatten(values, spec)
|
||||
result = pytree_impl.tree_unflatten(values, spec)
|
||||
|
||||
self.assertEqual(type(result), type(expected))
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_flatten_unflatten_dict(self):
|
||||
def run_test(dct):
|
||||
expected_spec = py_pytree.TreeSpec(
|
||||
dict, list(dct.keys()), [py_pytree.LeafSpec() for _ in dct.values()]
|
||||
)
|
||||
values, treespec = py_pytree.tree_flatten(dct)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(values, list(dct.values()))
|
||||
self.assertEqual(treespec, expected_spec)
|
||||
|
||||
unflattened = py_pytree.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, dct)
|
||||
self.assertTrue(isinstance(unflattened, dict))
|
||||
|
||||
run_test({})
|
||||
run_test({"a": 1})
|
||||
run_test({"abcdefg": torch.randn(2, 3)})
|
||||
run_test({1: torch.randn(2, 3)})
|
||||
run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})
|
||||
|
||||
def test_flatten_unflatten_nested(self):
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_flatten_unflatten_nested(self, pytree_impl):
|
||||
def run_test(pytree):
|
||||
values, treespec = py_pytree.tree_flatten(pytree)
|
||||
values, treespec = pytree_impl.tree_flatten(pytree)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(len(values), treespec.num_leaves)
|
||||
|
||||
# NB: python basic data structures (dict list tuple) all have
|
||||
# contents equality defined on them, so the following works for them.
|
||||
unflattened = py_pytree.tree_unflatten(values, treespec)
|
||||
unflattened = pytree_impl.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, pytree)
|
||||
|
||||
cases = [
|
||||
@ -200,20 +286,27 @@ class TestPytree(TestCase):
|
||||
for case in cases:
|
||||
run_test(case)
|
||||
|
||||
def test_treemap(self):
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_treemap(self, pytree_impl):
|
||||
def run_test(pytree):
|
||||
def f(x):
|
||||
return x * 3
|
||||
|
||||
sm1 = sum(map(f, py_pytree.tree_flatten(pytree)[0]))
|
||||
sm2 = sum(py_pytree.tree_flatten(py_pytree.tree_map(f, pytree))[0])
|
||||
sm1 = sum(map(f, pytree_impl.tree_flatten(pytree)[0]))
|
||||
sm2 = sum(pytree_impl.tree_flatten(pytree_impl.tree_map(f, pytree))[0])
|
||||
self.assertEqual(sm1, sm2)
|
||||
|
||||
def invf(x):
|
||||
return x // 3
|
||||
|
||||
self.assertEqual(
|
||||
py_pytree.tree_map(invf, py_pytree.tree_map(f, pytree)),
|
||||
pytree_impl.tree_map(invf, pytree_impl.tree_map(f, pytree)),
|
||||
pytree,
|
||||
)
|
||||
|
||||
@ -227,51 +320,43 @@ class TestPytree(TestCase):
|
||||
for case in cases:
|
||||
run_test(case)
|
||||
|
||||
def test_tree_only(self):
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_tree_only(self, pytree_impl):
|
||||
self.assertEqual(
|
||||
py_pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
|
||||
pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
|
||||
)
|
||||
|
||||
def test_tree_all_any(self):
|
||||
self.assertTrue(py_pytree.tree_all(lambda x: x % 2, [1, 3]))
|
||||
self.assertFalse(py_pytree.tree_all(lambda x: x % 2, [0, 1]))
|
||||
self.assertTrue(py_pytree.tree_any(lambda x: x % 2, [0, 1]))
|
||||
self.assertFalse(py_pytree.tree_any(lambda x: x % 2, [0, 2]))
|
||||
self.assertTrue(py_pytree.tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
|
||||
self.assertFalse(py_pytree.tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
|
||||
self.assertTrue(py_pytree.tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
|
||||
self.assertFalse(py_pytree.tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_tree_all_any(self, pytree_impl):
|
||||
self.assertTrue(pytree_impl.tree_all(lambda x: x % 2, [1, 3]))
|
||||
self.assertFalse(pytree_impl.tree_all(lambda x: x % 2, [0, 1]))
|
||||
self.assertTrue(pytree_impl.tree_any(lambda x: x % 2, [0, 1]))
|
||||
self.assertFalse(pytree_impl.tree_any(lambda x: x % 2, [0, 2]))
|
||||
self.assertTrue(pytree_impl.tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
|
||||
self.assertFalse(pytree_impl.tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
|
||||
self.assertTrue(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
|
||||
self.assertFalse(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))
|
||||
|
||||
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
|
||||
def test_treespec_repr(self):
|
||||
# Check that it looks sane
|
||||
pytree = (0, [0, 0, [0]])
|
||||
_, spec = py_pytree.tree_flatten(pytree)
|
||||
self.assertEqual(
|
||||
repr(spec),
|
||||
(
|
||||
"TreeSpec(tuple, None, [*,\n"
|
||||
" TreeSpec(list, None, [*,\n"
|
||||
" *,\n"
|
||||
" TreeSpec(list, None, [*])])])"
|
||||
),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
|
||||
def test_treespec_repr_dynamo(self):
|
||||
# Check that it looks sane
|
||||
pytree = (0, [0, 0, [0]])
|
||||
_, spec = py_pytree.tree_flatten(pytree)
|
||||
self.assertExpectedInline(
|
||||
repr(spec),
|
||||
"""\
|
||||
TreeSpec(tuple, None, [*,
|
||||
TreeSpec(list, None, [*,
|
||||
*,
|
||||
TreeSpec(list, None, [*])])])""",
|
||||
)
|
||||
|
||||
def test_broadcast_to_and_flatten(self):
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_broadcast_to_and_flatten(self, pytree_impl):
|
||||
cases = [
|
||||
(1, (), []),
|
||||
# Same (flat) structures
|
||||
@ -305,10 +390,70 @@ TreeSpec(tuple, None, [*,
|
||||
(([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
|
||||
]
|
||||
for pytree, to_pytree, expected in cases:
|
||||
_, to_spec = py_pytree.tree_flatten(to_pytree)
|
||||
result = py_pytree._broadcast_to_and_flatten(pytree, to_spec)
|
||||
_, to_spec = pytree_impl.tree_flatten(to_pytree)
|
||||
result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec)
|
||||
self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
|
||||
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_pytree_serialize_bad_input(self, pytree_impl):
|
||||
with self.assertRaises(TypeError):
|
||||
pytree_impl.treespec_dumps("random_blurb")
|
||||
|
||||
|
||||
class TestPythonPytree(TestCase):
|
||||
def test_treespec_equality(self):
|
||||
self.assertTrue(
|
||||
py_pytree.LeafSpec() == py_pytree.LeafSpec(),
|
||||
)
|
||||
self.assertTrue(
|
||||
py_pytree.TreeSpec(list, None, []) == py_pytree.TreeSpec(list, None, []),
|
||||
)
|
||||
self.assertTrue(
|
||||
py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()])
|
||||
== py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
|
||||
)
|
||||
self.assertFalse(
|
||||
py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []),
|
||||
)
|
||||
self.assertTrue(
|
||||
py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []),
|
||||
)
|
||||
|
||||
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
|
||||
def test_treespec_repr(self):
|
||||
# Check that it looks sane
|
||||
pytree = (0, [0, 0, [0]])
|
||||
_, spec = py_pytree.tree_flatten(pytree)
|
||||
self.assertEqual(
|
||||
repr(spec),
|
||||
(
|
||||
"TreeSpec(tuple, None, [*,\n"
|
||||
" TreeSpec(list, None, [*,\n"
|
||||
" *,\n"
|
||||
" TreeSpec(list, None, [*])])])"
|
||||
),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
|
||||
def test_treespec_repr_dynamo(self):
|
||||
# Check that it looks sane
|
||||
pytree = (0, [0, 0, [0]])
|
||||
_, spec = py_pytree.tree_flatten(pytree)
|
||||
self.assertExpectedInline(
|
||||
repr(spec),
|
||||
"""\
|
||||
TreeSpec(tuple, None, [*,
|
||||
TreeSpec(list, None, [*,
|
||||
*,
|
||||
TreeSpec(list, None, [*])])])""",
|
||||
)
|
||||
|
||||
@parametrize(
|
||||
"spec",
|
||||
[
|
||||
@ -448,10 +593,6 @@ TreeSpec(tuple, None, [*,
|
||||
):
|
||||
py_pytree.treespec_dumps(spec)
|
||||
|
||||
def test_pytree_serialize_bad_input(self):
|
||||
with self.assertRaises(AttributeError):
|
||||
py_pytree.treespec_dumps("random_blurb")
|
||||
|
||||
def test_pytree_serialize_bad_protocol(self):
|
||||
import json
|
||||
|
||||
@ -511,191 +652,6 @@ class TestCxxPytree(TestCase):
|
||||
def test_treespec_equality(self):
|
||||
self.assertTrue(cxx_pytree.LeafSpec() == cxx_pytree.LeafSpec())
|
||||
|
||||
def test_flatten_unflatten_leaf(self):
|
||||
def run_test_with_leaf(leaf):
|
||||
values, treespec = cxx_pytree.tree_flatten(leaf)
|
||||
self.assertEqual(values, [leaf])
|
||||
self.assertEqual(treespec, cxx_pytree.LeafSpec())
|
||||
|
||||
unflattened = cxx_pytree.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, leaf)
|
||||
|
||||
run_test_with_leaf(1)
|
||||
run_test_with_leaf(1.0)
|
||||
run_test_with_leaf(None)
|
||||
run_test_with_leaf(bool)
|
||||
run_test_with_leaf(torch.randn(3, 3))
|
||||
|
||||
def test_flatten_unflatten_list(self):
|
||||
def run_test(lst):
|
||||
expected_spec = cxx_pytree.tree_structure([0] * len(lst))
|
||||
values, treespec = cxx_pytree.tree_flatten(lst)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(values, lst)
|
||||
self.assertEqual(treespec, expected_spec)
|
||||
|
||||
unflattened = cxx_pytree.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, lst)
|
||||
self.assertTrue(isinstance(unflattened, list))
|
||||
|
||||
run_test([])
|
||||
run_test([1.0, 2])
|
||||
run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])
|
||||
|
||||
def test_flatten_unflatten_tuple(self):
|
||||
def run_test(tup):
|
||||
expected_spec = cxx_pytree.tree_structure((0,) * len(tup))
|
||||
values, treespec = cxx_pytree.tree_flatten(tup)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(values, list(tup))
|
||||
self.assertEqual(treespec, expected_spec)
|
||||
|
||||
unflattened = cxx_pytree.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, tup)
|
||||
self.assertTrue(isinstance(unflattened, tuple))
|
||||
|
||||
run_test(())
|
||||
run_test((1.0,))
|
||||
run_test((1.0, 2))
|
||||
run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))
|
||||
|
||||
def test_flatten_unflatten_odict(self):
|
||||
def run_test(odict):
|
||||
expected_spec = cxx_pytree.tree_structure(OrderedDict.fromkeys(odict, 0))
|
||||
values, treespec = cxx_pytree.tree_flatten(odict)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(values, list(odict.values()))
|
||||
self.assertEqual(treespec, expected_spec)
|
||||
|
||||
unflattened = cxx_pytree.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, odict)
|
||||
self.assertTrue(isinstance(unflattened, OrderedDict))
|
||||
|
||||
od = OrderedDict()
|
||||
run_test(od)
|
||||
|
||||
od["b"] = 1
|
||||
od["a"] = torch.tensor(3.14)
|
||||
run_test(od)
|
||||
|
||||
def test_flatten_unflatten_namedtuple(self):
|
||||
Point = namedtuple("Point", ["x", "y"])
|
||||
|
||||
def run_test(tup):
|
||||
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
|
||||
values, treespec = cxx_pytree.tree_flatten(tup)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(values, list(tup))
|
||||
self.assertEqual(treespec, expected_spec)
|
||||
|
||||
unflattened = cxx_pytree.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, tup)
|
||||
self.assertTrue(isinstance(unflattened, Point))
|
||||
|
||||
run_test(Point(1.0, 2))
|
||||
run_test(Point(torch.tensor(1.0), 2))
|
||||
|
||||
@parametrize(
|
||||
"op",
|
||||
[
|
||||
subtest(torch.max, name="max"),
|
||||
subtest(torch.min, name="min"),
|
||||
],
|
||||
)
|
||||
def test_flatten_unflatten_return_type(self, op):
|
||||
x = torch.randn(3, 3)
|
||||
expected = op(x, dim=0)
|
||||
|
||||
values, spec = cxx_pytree.tree_flatten(expected)
|
||||
# Check that values is actually List[Tensor] and not (ReturnType(...),)
|
||||
for value in values:
|
||||
self.assertTrue(isinstance(value, torch.Tensor))
|
||||
result = cxx_pytree.tree_unflatten(values, spec)
|
||||
|
||||
self.assertEqual(type(result), type(expected))
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_flatten_unflatten_dict(self):
|
||||
def run_test(dct):
|
||||
expected_spec = cxx_pytree.tree_structure(dict.fromkeys(dct, 0))
|
||||
values, treespec = cxx_pytree.tree_flatten(dct)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(values, list(dct.values()))
|
||||
self.assertEqual(treespec, expected_spec)
|
||||
|
||||
unflattened = cxx_pytree.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, dct)
|
||||
self.assertTrue(isinstance(unflattened, dict))
|
||||
|
||||
run_test({})
|
||||
run_test({"a": 1})
|
||||
run_test({"abcdefg": torch.randn(2, 3)})
|
||||
run_test({1: torch.randn(2, 3)})
|
||||
run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})
|
||||
|
||||
def test_flatten_unflatten_nested(self):
|
||||
def run_test(pytree):
|
||||
values, treespec = cxx_pytree.tree_flatten(pytree)
|
||||
self.assertTrue(isinstance(values, list))
|
||||
self.assertEqual(len(values), treespec.num_leaves)
|
||||
|
||||
# NB: python basic data structures (dict list tuple) all have
|
||||
# contents equality defined on them, so the following works for them.
|
||||
unflattened = cxx_pytree.tree_unflatten(values, treespec)
|
||||
self.assertEqual(unflattened, pytree)
|
||||
|
||||
cases = [
|
||||
[()],
|
||||
([],),
|
||||
{"a": ()},
|
||||
{"a": 0, "b": [{"c": 1}]},
|
||||
{"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
|
||||
]
|
||||
for case in cases:
|
||||
run_test(case)
|
||||
|
||||
def test_treemap(self):
|
||||
def run_test(pytree):
|
||||
def f(x):
|
||||
return x * 3
|
||||
|
||||
sm1 = sum(map(f, cxx_pytree.tree_flatten(pytree)[0]))
|
||||
sm2 = sum(cxx_pytree.tree_flatten(cxx_pytree.tree_map(f, pytree))[0])
|
||||
self.assertEqual(sm1, sm2)
|
||||
|
||||
def invf(x):
|
||||
return x // 3
|
||||
|
||||
self.assertEqual(
|
||||
cxx_pytree.tree_map(invf, cxx_pytree.tree_map(f, pytree)),
|
||||
pytree,
|
||||
)
|
||||
|
||||
cases = [
|
||||
[()],
|
||||
([],),
|
||||
{"a": ()},
|
||||
{"a": 1, "b": [{"c": 2}]},
|
||||
{"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
|
||||
]
|
||||
for case in cases:
|
||||
run_test(case)
|
||||
|
||||
def test_tree_only(self):
|
||||
self.assertEqual(
|
||||
cxx_pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
|
||||
)
|
||||
|
||||
def test_tree_all_any(self):
|
||||
self.assertTrue(cxx_pytree.tree_all(lambda x: x % 2, [1, 3]))
|
||||
self.assertFalse(cxx_pytree.tree_all(lambda x: x % 2, [0, 1]))
|
||||
self.assertTrue(cxx_pytree.tree_any(lambda x: x % 2, [0, 1]))
|
||||
self.assertFalse(cxx_pytree.tree_any(lambda x: x % 2, [0, 2]))
|
||||
self.assertTrue(cxx_pytree.tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
|
||||
self.assertFalse(cxx_pytree.tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
|
||||
self.assertTrue(cxx_pytree.tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
|
||||
self.assertFalse(cxx_pytree.tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))
|
||||
|
||||
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
|
||||
def test_treespec_repr(self):
|
||||
# Check that it looks sane
|
||||
@ -716,44 +672,6 @@ class TestCxxPytree(TestCase):
|
||||
"PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)",
|
||||
)
|
||||
|
||||
def test_broadcast_to_and_flatten(self):
|
||||
cases = [
|
||||
(1, (), []),
|
||||
# Same (flat) structures
|
||||
((1,), (0,), [1]),
|
||||
([1], [0], [1]),
|
||||
((1, 2, 3), (0, 0, 0), [1, 2, 3]),
|
||||
({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
|
||||
# Mismatched (flat) structures
|
||||
([1], (0,), None),
|
||||
([1], (0,), None),
|
||||
((1,), [0], None),
|
||||
((1, 2, 3), (0, 0), None),
|
||||
({"a": 1, "b": 2}, {"a": 0}, None),
|
||||
({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
|
||||
({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
|
||||
# Same (nested) structures
|
||||
((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
|
||||
((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
|
||||
# Mismatched (nested) structures
|
||||
((1, [2, 3]), (0, (0, 0)), None),
|
||||
((1, [2, 3]), (0, [0, 0, 0]), None),
|
||||
# Broadcasting single value
|
||||
(1, (0, 0, 0), [1, 1, 1]),
|
||||
(1, [0, 0, 0], [1, 1, 1]),
|
||||
(1, {"a": 0, "b": 0}, [1, 1]),
|
||||
(1, (0, [0, [0]], 0), [1, 1, 1, 1]),
|
||||
(1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
|
||||
# Broadcast multiple things
|
||||
((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
|
||||
((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
|
||||
(([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
|
||||
]
|
||||
for pytree, to_pytree, expected in cases:
|
||||
_, to_spec = cxx_pytree.tree_flatten(to_pytree)
|
||||
result = cxx_pytree._broadcast_to_and_flatten(pytree, to_spec)
|
||||
self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
|
||||
|
||||
@parametrize(
|
||||
"spec",
|
||||
[
|
||||
@ -772,20 +690,20 @@ class TestCxxPytree(TestCase):
|
||||
)
|
||||
def test_pytree_serialize(self, spec):
|
||||
serialized_spec = cxx_pytree.treespec_dumps(spec)
|
||||
self.assertTrue(isinstance(serialized_spec, bytes))
|
||||
self.assertTrue(isinstance(serialized_spec, str))
|
||||
self.assertTrue(spec == cxx_pytree.treespec_loads(serialized_spec))
|
||||
|
||||
def test_pytree_serialize_namedtuple(self):
|
||||
spec = cxx_pytree.tree_structure(GlobalPoint(0, 1))
|
||||
|
||||
roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
|
||||
self.assertEqual(roundtrip_spec, spec)
|
||||
self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)
|
||||
|
||||
LocalPoint = namedtuple("LocalPoint", ["x", "y"])
|
||||
spec = cxx_pytree.tree_structure(LocalPoint(0, 1))
|
||||
|
||||
with self.assertRaises(pickle.PicklingError):
|
||||
cxx_pytree.treespec_dumps(spec)
|
||||
roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
|
||||
self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)
|
||||
|
||||
def test_pytree_custom_type_serialize(self):
|
||||
cxx_pytree.register_pytree_node(
|
||||
@ -809,16 +727,15 @@ class TestCxxPytree(TestCase):
|
||||
lambda xs, _: LocalDummyType(*xs),
|
||||
)
|
||||
spec = cxx_pytree.tree_structure(LocalDummyType(0, 1))
|
||||
with self.assertRaises(AttributeError):
|
||||
serialized_spec = cxx_pytree.treespec_dumps(spec)
|
||||
|
||||
def test_pytree_serialize_bad_input(self):
|
||||
with self.assertRaises(TypeError):
|
||||
cxx_pytree.treespec_dumps("random_blurb")
|
||||
serialized_spec = cxx_pytree.treespec_dumps(spec)
|
||||
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
|
||||
self.assertEqual(roundtrip_spec, spec)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestPytree)
|
||||
instantiate_parametrized_tests(TestGenericPytree)
|
||||
instantiate_parametrized_tests(TestPythonPytree)
|
||||
instantiate_parametrized_tests(TestCxxPytree)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -6,9 +6,9 @@ import torch.utils._pytree as pytree
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_autograd_fallback_mode(mode):
|
||||
prev = torch._C._get_autograd_fallback_mode()
|
||||
try:
|
||||
prev = torch._C._get_autograd_fallback_mode()
|
||||
torch._C._set_autograd_fallback_mode("nothing")
|
||||
torch._C._set_autograd_fallback_mode(mode)
|
||||
yield
|
||||
finally:
|
||||
torch._C._set_autograd_fallback_mode(prev)
|
||||
@ -70,8 +70,9 @@ def autograd_registration_check(op, args, kwargs):
|
||||
# constructing true in-place or out variants), but we defer that
|
||||
# responsibility to a different test (schema_check).
|
||||
|
||||
all_args = (args, kwargs)
|
||||
if not pytree.tree_any_only(torch.Tensor, lambda x: x.requires_grad, all_args):
|
||||
flat_args = pytree.tree_leaves((args, kwargs))
|
||||
all_tensors = [arg for arg in flat_args if isinstance(arg, torch.Tensor)]
|
||||
if not any(t.requires_grad for t in all_tensors):
|
||||
raise RuntimeError(
|
||||
"autograd_registration_check: no inputs have requires_grad=True so "
|
||||
"we are unable to actually perform this test. Please pass inputs "
|
||||
@ -79,8 +80,6 @@ def autograd_registration_check(op, args, kwargs):
|
||||
)
|
||||
|
||||
# Determine which AutogradBACKEND key to check
|
||||
flat_args, _ = pytree.tree_flatten(all_args)
|
||||
all_tensors = [arg for arg in flat_args if isinstance(arg, torch.Tensor)]
|
||||
all_device_types = {arg.device.type for arg in all_tensors}
|
||||
if not all_device_types.issubset(["cpu", "cuda"]):
|
||||
# Don't want to support other keys yet
|
||||
@ -106,7 +105,7 @@ def autograd_registration_check(op, args, kwargs):
|
||||
with set_autograd_fallback_mode("nothing"):
|
||||
all_outs = op(*args, **kwargs)
|
||||
|
||||
inp_ids = set({id(arg) for arg in flat_args})
|
||||
inp_ids = {id(arg) for arg in flat_args}
|
||||
|
||||
def not_an_input_and_requires_grad(tensor):
|
||||
if not tensor.requires_grad:
|
||||
|
@ -13,7 +13,6 @@ collection support for PyTorch APIs.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import pickle
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -110,6 +109,7 @@ def register_pytree_node(
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Registry a Python type with lambda functions
|
||||
>>> register_pytree_node(
|
||||
... set,
|
||||
@ -118,27 +118,30 @@ def register_pytree_node(
|
||||
... namespace='set',
|
||||
... )
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Register a Python type into a namespace
|
||||
>>> import torch
|
||||
>>> register_pytree_node(
|
||||
... torch.Tensor,
|
||||
... flatten_func=lambda tensor: (
|
||||
... (tensor.cpu().numpy(),),
|
||||
... dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
|
||||
... (tensor.cpu().detach().numpy(),),
|
||||
... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
|
||||
... ),
|
||||
... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata),
|
||||
... namespace='torch2numpy',
|
||||
... )
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
|
||||
>>> tree
|
||||
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
|
||||
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||
>>> # Flatten without specifying the namespace
|
||||
>>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP
|
||||
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Flatten with the namespace
|
||||
>>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP
|
||||
(
|
||||
@ -152,6 +155,7 @@ def register_pytree_node(
|
||||
)
|
||||
)
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Register the same type with a different namespace for different behaviors
|
||||
>>> def tensor2flatparam(tensor):
|
||||
... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
|
||||
@ -166,6 +170,7 @@ def register_pytree_node(
|
||||
... namespace='tensor2flatparam',
|
||||
... )
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # Flatten with the new namespace
|
||||
>>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP
|
||||
(
|
||||
@ -182,8 +187,19 @@ def register_pytree_node(
|
||||
)
|
||||
)
|
||||
"""
|
||||
from ._pytree import _register_pytree_node
|
||||
|
||||
_register_pytree_node(
|
||||
cls,
|
||||
flatten_func,
|
||||
unflatten_func,
|
||||
)
|
||||
|
||||
optree.register_pytree_node(
|
||||
cls, flatten_func, _reverse_args(unflatten_func), namespace=namespace
|
||||
cls,
|
||||
flatten_func,
|
||||
_reverse_args(unflatten_func),
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
@ -235,7 +251,11 @@ def tree_flatten(
|
||||
A pair ``(leaves, treespec)`` where the first element is a list of leaf values and the
|
||||
second element is a treespec representing the structure of the pytree.
|
||||
"""
|
||||
return optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace)
|
||||
return optree.tree_flatten( # type: ignore[return-value]
|
||||
tree,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
|
||||
@ -262,7 +282,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
|
||||
f"tree_unflatten(values, spec): Expected `spec` to be instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
return optree.tree_unflatten(treespec, leaves)
|
||||
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def tree_leaves(
|
||||
@ -334,7 +354,11 @@ def tree_structure(
|
||||
Returns:
|
||||
A treespec object representing the structure of the pytree.
|
||||
"""
|
||||
return optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace)
|
||||
return optree.tree_structure( # type: ignore[return-value]
|
||||
tree,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
def tree_map(
|
||||
@ -382,7 +406,11 @@ def tree_map(
|
||||
is the tuple of values at corresponding nodes in ``rests``.
|
||||
"""
|
||||
return optree.tree_map(
|
||||
func, tree, *rests, none_is_leaf=none_is_leaf, namespace=namespace
|
||||
func,
|
||||
tree,
|
||||
*rests,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
@ -416,7 +444,11 @@ def tree_map_(
|
||||
in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``.
|
||||
"""
|
||||
return optree.tree_map_(
|
||||
func, tree, *rests, none_is_leaf=none_is_leaf, namespace=namespace
|
||||
func,
|
||||
tree,
|
||||
*rests,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
@ -704,7 +736,10 @@ def broadcast_prefix(
|
||||
A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
|
||||
"""
|
||||
return optree.broadcast_prefix(
|
||||
prefix_tree, full_tree, none_is_leaf=none_is_leaf, namespace=namespace
|
||||
prefix_tree,
|
||||
full_tree,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
@ -727,39 +762,50 @@ def _broadcast_to_and_flatten(
|
||||
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
|
||||
try:
|
||||
return broadcast_prefix(
|
||||
tree, full_tree, none_is_leaf=none_is_leaf, namespace=namespace
|
||||
tree,
|
||||
full_tree,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def treespec_dumps(treespec: PyTreeSpec) -> bytes:
|
||||
"""Serialize a treespec to bytes."""
|
||||
def treespec_dumps(treespec: PyTreeSpec) -> str:
|
||||
"""Serialize a treespec to a JSON string."""
|
||||
if not isinstance(treespec, PyTreeSpec):
|
||||
raise TypeError(
|
||||
f"treespec_dumps(spec): Expected `spec` to be instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
return pickle.dumps(treespec)
|
||||
from ._pytree import (
|
||||
tree_structure as _tree_structure,
|
||||
treespec_dumps as _treespec_dumps,
|
||||
)
|
||||
|
||||
orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
|
||||
return _treespec_dumps(orig_treespec)
|
||||
|
||||
|
||||
def treespec_loads(serialized: bytes) -> PyTreeSpec:
|
||||
"""Deserialize a treespec from bytes."""
|
||||
treespec = pickle.loads(serialized)
|
||||
if not isinstance(treespec, PyTreeSpec):
|
||||
raise TypeError(
|
||||
f"treespec_loads(serialized): Expected to return an instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
def treespec_loads(serialized: str) -> PyTreeSpec:
|
||||
"""Deserialize a treespec from a JSON string."""
|
||||
from ._pytree import (
|
||||
tree_unflatten as _tree_unflatten,
|
||||
treespec_loads as _treespec_loads,
|
||||
)
|
||||
|
||||
orig_treespec = _treespec_loads(serialized)
|
||||
dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
|
||||
treespec = tree_structure(dummy_tree)
|
||||
return treespec
|
||||
|
||||
|
||||
class PyTreeLeafSpecMeta(type(optree.PyTreeSpec)): # type: ignore[misc]
|
||||
class PyTreeLeafSpecMeta(type(PyTreeSpec)): # type: ignore[misc]
|
||||
def __instancecheck__(self, instance: object) -> bool:
|
||||
return isinstance(instance, optree.PyTreeSpec) and instance.is_leaf()
|
||||
return isinstance(instance, PyTreeSpec) and instance.is_leaf()
|
||||
|
||||
|
||||
class PyTreeLeafSpec(optree.PyTreeSpec, metaclass=PyTreeLeafSpecMeta):
|
||||
class PyTreeLeafSpec(PyTreeSpec, metaclass=PyTreeLeafSpecMeta):
|
||||
def __new__(cls, none_is_leaf: bool = True) -> "PyTreeLeafSpec":
|
||||
return optree.treespec_leaf(none_is_leaf=none_is_leaf) # type: ignore[return-value]
|
||||
|
@ -1,15 +1,3 @@
|
||||
from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional, TypeVar, overload, Union
|
||||
from collections import namedtuple, OrderedDict
|
||||
import dataclasses
|
||||
import json
|
||||
import warnings
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
S = TypeVar('S')
|
||||
U = TypeVar('U')
|
||||
R = TypeVar('R')
|
||||
|
||||
"""
|
||||
Contains utility functions for working with nested python data structures.
|
||||
|
||||
@ -27,6 +15,32 @@ This pytree implementation is not very performant due to Python overhead
|
||||
To improve the performance we can move parts of the implementation to C++.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import warnings
|
||||
from collections import deque, namedtuple, OrderedDict
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
overload,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
U = TypeVar("U")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1
|
||||
|
||||
Context = Any
|
||||
@ -226,6 +240,8 @@ class TreeSpec:
|
||||
repr_suffix: str = f'{children_specs_str}])'
|
||||
return repr_prefix + repr_suffix
|
||||
|
||||
def is_leaf(self) -> bool:
|
||||
return isinstance(self, LeafSpec)
|
||||
|
||||
class LeafSpec(TreeSpec):
|
||||
def __init__(self) -> None:
|
||||
@ -256,6 +272,13 @@ def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
|
||||
|
||||
return result, TreeSpec(node_type, context, children_specs)
|
||||
|
||||
def tree_leaves(pytree: PyTree) -> List[Any]:
|
||||
"""Get a list of leaves of a pytree."""
|
||||
return tree_flatten(pytree)[0]
|
||||
|
||||
def tree_structure(pytree: PyTree) -> TreeSpec:
|
||||
"""Get the TreeSpec for a pytree."""
|
||||
return tree_flatten(pytree)[1]
|
||||
|
||||
def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
|
||||
"""Given a list of values and a TreeSpec, builds a pytree.
|
||||
@ -290,6 +313,11 @@ def tree_map(fn: Any, pytree: PyTree) -> PyTree:
|
||||
flat_args, spec = tree_flatten(pytree)
|
||||
return tree_unflatten([fn(i) for i in flat_args], spec)
|
||||
|
||||
def tree_map_(fn: Any, pytree: PyTree) -> PyTree:
|
||||
flat_args, _ = tree_flatten(pytree)
|
||||
deque(map(fn, flat_args), maxlen=0) # consume and exhaust the iterable
|
||||
return pytree
|
||||
|
||||
Type2 = Tuple[Type[T], Type[S]]
|
||||
Type3 = Tuple[Type[T], Type[S], Type[U]]
|
||||
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]
|
||||
@ -359,6 +387,21 @@ def tree_map_only(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) ->
|
||||
def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:
|
||||
return tree_map(map_only(ty)(fn), pytree)
|
||||
|
||||
@overload
|
||||
def tree_map_only_(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree:
|
||||
...
|
||||
|
||||
@overload
|
||||
def tree_map_only_(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree:
|
||||
...
|
||||
|
||||
@overload
|
||||
def tree_map_only_(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) -> PyTree:
|
||||
...
|
||||
|
||||
def tree_map_only_(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:
|
||||
return tree_map_(map_only(ty)(fn), pytree)
|
||||
|
||||
def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
|
||||
flat_args, _ = tree_flatten(pytree)
|
||||
return all(map(pred, flat_args))
|
||||
@ -433,15 +476,16 @@ def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[A
|
||||
return result
|
||||
|
||||
|
||||
"""
|
||||
_TreeSpecSchema is the schema used to serialize the TreeSpec
|
||||
It contains the following fields:
|
||||
- type: A string name of the type. null for the case of a LeafSpec.
|
||||
- context: Any format which is json dumpable
|
||||
- children_spec: A list of children serialized specs.
|
||||
"""
|
||||
@dataclasses.dataclass
|
||||
class _TreeSpecSchema:
|
||||
"""
|
||||
_TreeSpecSchema is the schema used to serialize the TreeSpec
|
||||
It contains the following fields:
|
||||
- type: A string name of the type. null for the case of a LeafSpec.
|
||||
- context: Any format which is json dumpable
|
||||
- children_spec: A list of children serialized specs.
|
||||
"""
|
||||
|
||||
type: Optional[str]
|
||||
context: DumpableContext
|
||||
children_spec: List['_TreeSpecSchema']
|
||||
@ -517,6 +561,12 @@ _SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)
|
||||
|
||||
|
||||
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
|
||||
if not isinstance(treespec, TreeSpec):
|
||||
raise TypeError(
|
||||
f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
|
||||
f"TreeSpec but got item of type {type(treespec)}.",
|
||||
)
|
||||
|
||||
if protocol is None:
|
||||
protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL
|
||||
|
||||
|
Reference in New Issue
Block a user