mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160842 Approved by: https://github.com/Skylion007
1547 lines
55 KiB
Python
1547 lines
55 KiB
Python
# Owner(s): ["module: pytree"]
|
|
|
|
import enum
|
|
import inspect
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import unittest
|
|
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
|
|
from dataclasses import dataclass, field
|
|
from enum import auto
|
|
from typing import Any, NamedTuple, Optional
|
|
|
|
import torch
|
|
import torch.utils._pytree as python_pytree
|
|
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
IS_FBCODE,
|
|
parametrize,
|
|
run_tests,
|
|
subtest,
|
|
TestCase,
|
|
)
|
|
|
|
|
|
pytree_modules = {
|
|
"python": python_pytree,
|
|
}
|
|
if not IS_FBCODE:
|
|
import torch.utils._cxx_pytree as cxx_pytree
|
|
|
|
pytree_modules["cxx"] = cxx_pytree
|
|
else:
|
|
# optree is not yet enabled in fbcode, so just re-test the python implementation
|
|
cxx_pytree = python_pytree
|
|
|
|
|
|
parametrize_pytree_module = parametrize(
|
|
"pytree",
|
|
[subtest(module, name=name) for name, module in pytree_modules.items()],
|
|
)
|
|
|
|
|
|
GlobalPoint = namedtuple("GlobalPoint", ["x", "y"])
|
|
|
|
|
|
class GlobalDummyType:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
|
|
cxx_pytree.register_pytree_node(
|
|
GlobalDummyType,
|
|
lambda dummy: ([dummy.x, dummy.y], None),
|
|
lambda xs, _: GlobalDummyType(*xs),
|
|
serialized_type_name="GlobalDummyType",
|
|
)
|
|
|
|
|
|
class TestEnum(enum.Enum):
|
|
A = auto()
|
|
|
|
|
|
python_leafspec = python_pytree.LeafSpec()
|
|
|
|
|
|
class TestGenericPytree(TestCase):
|
|
def test_aligned_public_apis(self):
|
|
public_apis = python_pytree.__all__
|
|
|
|
self.assertEqual(public_apis, cxx_pytree.__all__)
|
|
|
|
for name in public_apis:
|
|
cxx_api = getattr(cxx_pytree, name)
|
|
python_api = getattr(python_pytree, name)
|
|
|
|
self.assertEqual(inspect.isclass(cxx_api), inspect.isclass(python_api))
|
|
self.assertEqual(
|
|
inspect.isfunction(cxx_api),
|
|
inspect.isfunction(python_api),
|
|
)
|
|
if inspect.isfunction(cxx_api):
|
|
cxx_signature = inspect.signature(cxx_api)
|
|
python_signature = inspect.signature(python_api)
|
|
|
|
# Check the parameter names are the same.
|
|
cxx_param_names = list(cxx_signature.parameters)
|
|
python_param_names = list(python_signature.parameters)
|
|
self.assertEqual(cxx_param_names, python_param_names)
|
|
|
|
# Check the positional parameters are the same.
|
|
cxx_positional_param_names = [
|
|
n
|
|
for n, p in cxx_signature.parameters.items()
|
|
if (
|
|
p.kind
|
|
in {
|
|
inspect.Parameter.POSITIONAL_ONLY,
|
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
}
|
|
)
|
|
]
|
|
python_positional_param_names = [
|
|
n
|
|
for n, p in python_signature.parameters.items()
|
|
if (
|
|
p.kind
|
|
in {
|
|
inspect.Parameter.POSITIONAL_ONLY,
|
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
}
|
|
)
|
|
]
|
|
self.assertEqual(
|
|
cxx_positional_param_names,
|
|
python_positional_param_names,
|
|
)
|
|
|
|
for python_name, python_param in python_signature.parameters.items():
|
|
self.assertIn(python_name, cxx_signature.parameters)
|
|
cxx_param = cxx_signature.parameters[python_name]
|
|
|
|
# Check parameter kinds and default values are the same.
|
|
self.assertEqual(cxx_param.kind, python_param.kind)
|
|
self.assertEqual(cxx_param.default, python_param.default)
|
|
|
|
# Check parameter annotations are the same.
|
|
if "TreeSpec" in str(cxx_param.annotation):
|
|
self.assertIn("TreeSpec", str(python_param.annotation))
|
|
self.assertEqual(
|
|
re.sub(
|
|
r"(?:\b)([\w\.]*)TreeSpec(?:\b)",
|
|
"TreeSpec",
|
|
str(cxx_param.annotation),
|
|
),
|
|
re.sub(
|
|
r"(?:\b)([\w\.]*)TreeSpec(?:\b)",
|
|
"TreeSpec",
|
|
str(python_param.annotation),
|
|
),
|
|
msg=(
|
|
f"C++ parameter {cxx_param} "
|
|
f"does not match Python parameter {python_param} "
|
|
f"for API `{name}`"
|
|
),
|
|
)
|
|
else:
|
|
self.assertEqual(
|
|
cxx_param.annotation,
|
|
python_param.annotation,
|
|
msg=(
|
|
f"C++ parameter {cxx_param} "
|
|
f"does not match Python parameter {python_param} "
|
|
f"for API `{name}`"
|
|
),
|
|
)
|
|
|
|
@parametrize_pytree_module
|
|
def test_register_pytree_node(self, pytree):
|
|
class MyDict(UserDict):
|
|
pass
|
|
|
|
d = MyDict(a=1, b=2, c=3)
|
|
|
|
# Custom types are leaf nodes by default
|
|
values, spec = pytree.tree_flatten(d)
|
|
self.assertEqual(values, [d])
|
|
self.assertIs(values[0], d)
|
|
self.assertEqual(d, pytree.tree_unflatten(values, spec))
|
|
self.assertTrue(spec.is_leaf())
|
|
|
|
# Register MyDict as a pytree node
|
|
pytree.register_pytree_node(
|
|
MyDict,
|
|
lambda d: (list(d.values()), list(d.keys())),
|
|
lambda values, keys: MyDict(zip(keys, values)),
|
|
)
|
|
|
|
values, spec = pytree.tree_flatten(d)
|
|
self.assertEqual(values, [1, 2, 3])
|
|
self.assertEqual(d, pytree.tree_unflatten(values, spec))
|
|
|
|
# Do not allow registering the same type twice
|
|
with self.assertRaisesRegex(ValueError, "already registered"):
|
|
pytree.register_pytree_node(
|
|
MyDict,
|
|
lambda d: (list(d.values()), list(d.keys())),
|
|
lambda values, keys: MyDict(zip(keys, values)),
|
|
)
|
|
|
|
@parametrize_pytree_module
|
|
def test_flatten_unflatten_leaf(self, pytree):
|
|
def run_test_with_leaf(leaf):
|
|
values, treespec = pytree.tree_flatten(leaf)
|
|
self.assertEqual(values, [leaf])
|
|
self.assertEqual(treespec, pytree.LeafSpec())
|
|
|
|
unflattened = pytree.tree_unflatten(values, treespec)
|
|
self.assertEqual(unflattened, leaf)
|
|
|
|
run_test_with_leaf(1)
|
|
run_test_with_leaf(1.0)
|
|
run_test_with_leaf(None)
|
|
run_test_with_leaf(bool)
|
|
run_test_with_leaf(torch.randn(3, 3))
|
|
|
|
@parametrize(
|
|
"pytree,gen_expected_fn",
|
|
[
|
|
subtest(
|
|
(
|
|
python_pytree,
|
|
lambda tup: python_pytree.TreeSpec(
|
|
tuple, None, [python_leafspec for _ in tup]
|
|
),
|
|
),
|
|
name="python",
|
|
),
|
|
subtest(
|
|
(cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))),
|
|
name="cxx",
|
|
),
|
|
],
|
|
)
|
|
def test_flatten_unflatten_tuple(self, pytree, gen_expected_fn):
|
|
def run_test(tup):
|
|
expected_spec = gen_expected_fn(tup)
|
|
values, treespec = pytree.tree_flatten(tup)
|
|
self.assertIsInstance(values, list)
|
|
self.assertEqual(values, list(tup))
|
|
self.assertEqual(treespec, expected_spec)
|
|
|
|
unflattened = pytree.tree_unflatten(values, treespec)
|
|
self.assertEqual(unflattened, tup)
|
|
self.assertIsInstance(unflattened, tuple)
|
|
|
|
run_test(())
|
|
run_test((1.0,))
|
|
run_test((1.0, 2))
|
|
run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))
|
|
|
|
@parametrize(
|
|
"pytree,gen_expected_fn",
|
|
[
|
|
subtest(
|
|
(
|
|
python_pytree,
|
|
lambda lst: python_pytree.TreeSpec(
|
|
list, None, [python_leafspec for _ in lst]
|
|
),
|
|
),
|
|
name="python",
|
|
),
|
|
subtest(
|
|
(cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))),
|
|
name="cxx",
|
|
),
|
|
],
|
|
)
|
|
def test_flatten_unflatten_list(self, pytree, gen_expected_fn):
|
|
def run_test(lst):
|
|
expected_spec = gen_expected_fn(lst)
|
|
values, treespec = pytree.tree_flatten(lst)
|
|
self.assertIsInstance(values, list)
|
|
self.assertEqual(values, lst)
|
|
self.assertEqual(treespec, expected_spec)
|
|
|
|
unflattened = pytree.tree_unflatten(values, treespec)
|
|
self.assertEqual(unflattened, lst)
|
|
self.assertIsInstance(unflattened, list)
|
|
|
|
run_test([])
|
|
run_test([1.0, 2])
|
|
run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])
|
|
|
|
@parametrize(
|
|
"pytree,gen_expected_fn",
|
|
[
|
|
subtest(
|
|
(
|
|
python_pytree,
|
|
lambda dct: python_pytree.TreeSpec(
|
|
dict,
|
|
list(dct.keys()),
|
|
[python_leafspec for _ in dct.values()],
|
|
),
|
|
),
|
|
name="python",
|
|
),
|
|
subtest(
|
|
(
|
|
cxx_pytree,
|
|
lambda dct: cxx_pytree.tree_structure(dict.fromkeys(dct, 0)),
|
|
),
|
|
name="cxx",
|
|
),
|
|
],
|
|
)
|
|
def test_flatten_unflatten_dict(self, pytree, gen_expected_fn):
|
|
def run_test(dct):
|
|
expected_spec = gen_expected_fn(dct)
|
|
values, treespec = pytree.tree_flatten(dct)
|
|
self.assertIsInstance(values, list)
|
|
self.assertEqual(values, list(dct.values()))
|
|
self.assertEqual(treespec, expected_spec)
|
|
|
|
unflattened = pytree.tree_unflatten(values, treespec)
|
|
self.assertEqual(unflattened, dct)
|
|
self.assertIsInstance(unflattened, dict)
|
|
|
|
run_test({})
|
|
run_test({"a": 1})
|
|
run_test({"abcdefg": torch.randn(2, 3)})
|
|
run_test({1: torch.randn(2, 3)})
|
|
run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})
|
|
|
|
@parametrize(
|
|
"pytree,gen_expected_fn",
|
|
[
|
|
subtest(
|
|
(
|
|
python_pytree,
|
|
lambda odict: python_pytree.TreeSpec(
|
|
OrderedDict,
|
|
list(odict.keys()),
|
|
[python_leafspec for _ in odict.values()],
|
|
),
|
|
),
|
|
name="python",
|
|
),
|
|
subtest(
|
|
(
|
|
cxx_pytree,
|
|
lambda odict: cxx_pytree.tree_structure(
|
|
OrderedDict.fromkeys(odict, 0)
|
|
),
|
|
),
|
|
name="cxx",
|
|
),
|
|
],
|
|
)
|
|
def test_flatten_unflatten_ordereddict(self, pytree, gen_expected_fn):
|
|
def run_test(odict):
|
|
expected_spec = gen_expected_fn(odict)
|
|
values, treespec = pytree.tree_flatten(odict)
|
|
self.assertIsInstance(values, list)
|
|
self.assertEqual(values, list(odict.values()))
|
|
self.assertEqual(treespec, expected_spec)
|
|
|
|
unflattened = pytree.tree_unflatten(values, treespec)
|
|
self.assertEqual(unflattened, odict)
|
|
self.assertIsInstance(unflattened, OrderedDict)
|
|
|
|
od = OrderedDict()
|
|
run_test(od)
|
|
|
|
od["b"] = 1
|
|
od["a"] = torch.tensor(3.14)
|
|
run_test(od)
|
|
|
|
@parametrize(
|
|
"pytree,gen_expected_fn",
|
|
[
|
|
subtest(
|
|
(
|
|
python_pytree,
|
|
lambda ddct: python_pytree.TreeSpec(
|
|
defaultdict,
|
|
[ddct.default_factory, list(ddct.keys())],
|
|
[python_leafspec for _ in ddct.values()],
|
|
),
|
|
),
|
|
name="python",
|
|
),
|
|
subtest(
|
|
(
|
|
cxx_pytree,
|
|
lambda ddct: cxx_pytree.tree_structure(
|
|
defaultdict(ddct.default_factory, dict.fromkeys(ddct, 0))
|
|
),
|
|
),
|
|
name="cxx",
|
|
),
|
|
],
|
|
)
|
|
def test_flatten_unflatten_defaultdict(self, pytree, gen_expected_fn):
|
|
def run_test(ddct):
|
|
expected_spec = gen_expected_fn(ddct)
|
|
values, treespec = pytree.tree_flatten(ddct)
|
|
self.assertIsInstance(values, list)
|
|
self.assertEqual(values, list(ddct.values()))
|
|
self.assertEqual(treespec, expected_spec)
|
|
|
|
unflattened = pytree.tree_unflatten(values, treespec)
|
|
self.assertEqual(unflattened, ddct)
|
|
self.assertEqual(unflattened.default_factory, ddct.default_factory)
|
|
self.assertIsInstance(unflattened, defaultdict)
|
|
|
|
run_test(defaultdict(list, {}))
|
|
run_test(defaultdict(int, {"a": 1}))
|
|
run_test(defaultdict(int, {"abcdefg": torch.randn(2, 3)}))
|
|
run_test(defaultdict(int, {1: torch.randn(2, 3)}))
|
|
run_test(defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)}))
|
|
|
|
@parametrize(
|
|
"pytree,gen_expected_fn",
|
|
[
|
|
subtest(
|
|
(
|
|
python_pytree,
|
|
lambda deq: python_pytree.TreeSpec(
|
|
deque, deq.maxlen, [python_leafspec for _ in deq]
|
|
),
|
|
),
|
|
name="python",
|
|
),
|
|
subtest(
|
|
(
|
|
cxx_pytree,
|
|
lambda deq: cxx_pytree.tree_structure(
|
|
deque(deq, maxlen=deq.maxlen)
|
|
),
|
|
),
|
|
name="cxx",
|
|
),
|
|
],
|
|
)
|
|
def test_flatten_unflatten_deque(self, pytree, gen_expected_fn):
|
|
def run_test(deq):
|
|
expected_spec = gen_expected_fn(deq)
|
|
values, treespec = pytree.tree_flatten(deq)
|
|
self.assertIsInstance(values, list)
|
|
self.assertEqual(values, list(deq))
|
|
self.assertEqual(treespec, expected_spec)
|
|
|
|
unflattened = pytree.tree_unflatten(values, treespec)
|
|
self.assertEqual(unflattened, deq)
|
|
self.assertEqual(unflattened.maxlen, deq.maxlen)
|
|
self.assertIsInstance(unflattened, deque)
|
|
|
|
run_test(deque([]))
|
|
run_test(deque([1.0, 2]))
|
|
run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
|
|
|
|
@parametrize_pytree_module
|
|
def test_flatten_unflatten_namedtuple(self, pytree):
|
|
Point = namedtuple("Point", ["x", "y"])
|
|
|
|
def run_test(tup):
|
|
if pytree is python_pytree:
|
|
expected_spec = python_pytree.TreeSpec(
|
|
namedtuple, Point, [python_leafspec for _ in tup]
|
|
)
|
|
else:
|
|
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
|
|
values, treespec = pytree.tree_flatten(tup)
|
|
self.assertIsInstance(values, list)
|
|
self.assertEqual(values, list(tup))
|
|
self.assertEqual(treespec, expected_spec)
|
|
|
|
unflattened = pytree.tree_unflatten(values, treespec)
|
|
self.assertEqual(unflattened, tup)
|
|
self.assertIsInstance(unflattened, Point)
|
|
|
|
run_test(Point(1.0, 2))
|
|
run_test(Point(torch.tensor(1.0), 2))
|
|
|
|
@parametrize(
|
|
"op",
|
|
[
|
|
subtest(torch.max, name="max"),
|
|
subtest(torch.min, name="min"),
|
|
],
|
|
)
|
|
@parametrize_pytree_module
|
|
def test_flatten_unflatten_return_types(self, pytree, op):
|
|
x = torch.randn(3, 3)
|
|
expected = op(x, dim=0)
|
|
|
|
values, spec = pytree.tree_flatten(expected)
|
|
# Check that values is actually List[Tensor] and not (ReturnType(...),)
|
|
for value in values:
|
|
self.assertIsInstance(value, torch.Tensor)
|
|
result = pytree.tree_unflatten(values, spec)
|
|
|
|
self.assertEqual(type(result), type(expected))
|
|
self.assertEqual(result, expected)
|
|
|
|
@parametrize_pytree_module
|
|
def test_flatten_unflatten_nested(self, pytree):
|
|
def run_test(tree):
|
|
values, treespec = pytree.tree_flatten(tree)
|
|
self.assertIsInstance(values, list)
|
|
self.assertEqual(len(values), treespec.num_leaves)
|
|
|
|
# NB: python basic data structures (dict list tuple) all have
|
|
# contents equality defined on them, so the following works for them.
|
|
unflattened = pytree.tree_unflatten(values, treespec)
|
|
self.assertEqual(unflattened, tree)
|
|
|
|
cases = [
|
|
[()],
|
|
([],),
|
|
{"a": ()},
|
|
{"a": 0, "b": [{"c": 1}]},
|
|
{"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
|
|
]
|
|
for case in cases:
|
|
run_test(case)
|
|
|
|
@parametrize_pytree_module
|
|
def test_flatten_with_is_leaf(self, pytree):
|
|
def run_test(tree, one_level_leaves):
|
|
values, treespec = pytree.tree_flatten(
|
|
tree, is_leaf=lambda x: x is not tree
|
|
)
|
|
self.assertIsInstance(values, list)
|
|
self.assertEqual(len(values), treespec.num_nodes - 1)
|
|
self.assertEqual(len(values), treespec.num_leaves)
|
|
self.assertEqual(len(values), treespec.num_children)
|
|
self.assertEqual(values, one_level_leaves)
|
|
|
|
self.assertEqual(
|
|
treespec,
|
|
pytree.tree_structure(
|
|
pytree.tree_unflatten([0] * treespec.num_leaves, treespec)
|
|
),
|
|
)
|
|
|
|
unflattened = pytree.tree_unflatten(values, treespec)
|
|
self.assertEqual(unflattened, tree)
|
|
|
|
cases = [
|
|
([()], [()]),
|
|
(([],), [[]]),
|
|
({"a": ()}, [()]),
|
|
({"a": 0, "b": [{"c": 1}]}, [0, [{"c": 1}]]),
|
|
(
|
|
{
|
|
"a": 0,
|
|
"b": [1, {"c": 2}, torch.ones(3)],
|
|
"c": (torch.zeros(2, 3), 1),
|
|
},
|
|
[0, [1, {"c": 2}, torch.ones(3)], (torch.zeros(2, 3), 1)],
|
|
),
|
|
]
|
|
for case in cases:
|
|
run_test(*case)
|
|
|
|
@parametrize_pytree_module
|
|
def test_tree_map(self, pytree):
|
|
def run_test(tree):
|
|
def f(x):
|
|
return x * 3
|
|
|
|
sm1 = sum(map(f, pytree.tree_leaves(tree)))
|
|
sm2 = sum(pytree.tree_leaves(pytree.tree_map(f, tree)))
|
|
self.assertEqual(sm1, sm2)
|
|
|
|
def invf(x):
|
|
return x // 3
|
|
|
|
self.assertEqual(
|
|
pytree.tree_map(invf, pytree.tree_map(f, tree)),
|
|
tree,
|
|
)
|
|
|
|
cases = [
|
|
[()],
|
|
([],),
|
|
{"a": ()},
|
|
{"a": 1, "b": [{"c": 2}]},
|
|
{"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
|
|
]
|
|
for case in cases:
|
|
run_test(case)
|
|
|
|
@parametrize_pytree_module
|
|
def test_tree_map_multi_inputs(self, pytree):
|
|
def run_test(tree):
|
|
def f(x, y, z):
|
|
return x, [y, (z, 0)]
|
|
|
|
tree_x = tree
|
|
tree_y = pytree.tree_map(lambda x: (x + 1,), tree)
|
|
tree_z = pytree.tree_map(lambda x: {"a": x * 2, "b": 2}, tree)
|
|
|
|
self.assertEqual(
|
|
pytree.tree_map(f, tree_x, tree_y, tree_z),
|
|
pytree.tree_map(lambda x: f(x, (x + 1,), {"a": x * 2, "b": 2}), tree),
|
|
)
|
|
|
|
cases = [
|
|
[()],
|
|
([],),
|
|
{"a": ()},
|
|
{"a": 1, "b": [{"c": 2}]},
|
|
{"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
|
|
]
|
|
for case in cases:
|
|
run_test(case)
|
|
|
|
@parametrize_pytree_module
|
|
def test_tree_map_only(self, pytree):
|
|
self.assertEqual(pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"])
|
|
|
|
@parametrize_pytree_module
|
|
def test_tree_map_only_predicate_fn(self, pytree):
|
|
self.assertEqual(
|
|
pytree.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1]
|
|
)
|
|
|
|
@parametrize_pytree_module
|
|
def test_tree_all_any(self, pytree):
|
|
self.assertTrue(pytree.tree_all(lambda x: x % 2, [1, 3]))
|
|
self.assertFalse(pytree.tree_all(lambda x: x % 2, [0, 1]))
|
|
self.assertTrue(pytree.tree_any(lambda x: x % 2, [0, 1]))
|
|
self.assertFalse(pytree.tree_any(lambda x: x % 2, [0, 2]))
|
|
self.assertTrue(pytree.tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
|
|
self.assertFalse(pytree.tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
|
|
self.assertTrue(pytree.tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
|
|
self.assertFalse(pytree.tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))
|
|
|
|
@parametrize_pytree_module
|
|
def test_broadcast_to_and_flatten(self, pytree):
|
|
cases = [
|
|
(1, (), []),
|
|
# Same (flat) structures
|
|
((1,), (0,), [1]),
|
|
([1], [0], [1]),
|
|
((1, 2, 3), (0, 0, 0), [1, 2, 3]),
|
|
({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
|
|
# Mismatched (flat) structures
|
|
([1], (0,), None),
|
|
([1], (0,), None),
|
|
((1,), [0], None),
|
|
((1, 2, 3), (0, 0), None),
|
|
({"a": 1, "b": 2}, {"a": 0}, None),
|
|
({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
|
|
({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
|
|
# Same (nested) structures
|
|
((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
|
|
((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
|
|
# Mismatched (nested) structures
|
|
((1, [2, 3]), (0, (0, 0)), None),
|
|
((1, [2, 3]), (0, [0, 0, 0]), None),
|
|
# Broadcasting single value
|
|
(1, (0, 0, 0), [1, 1, 1]),
|
|
(1, [0, 0, 0], [1, 1, 1]),
|
|
(1, {"a": 0, "b": 0}, [1, 1]),
|
|
(1, (0, [0, [0]], 0), [1, 1, 1, 1]),
|
|
(1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
|
|
# Broadcast multiple things
|
|
((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
|
|
((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
|
|
(([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
|
|
]
|
|
for tree, to_tree, expected in cases:
|
|
_, to_spec = pytree.tree_flatten(to_tree)
|
|
result = pytree._broadcast_to_and_flatten(tree, to_spec)
|
|
self.assertEqual(result, expected, msg=str([tree, to_spec, expected]))
|
|
|
|
@parametrize_pytree_module
|
|
def test_pytree_serialize_bad_input(self, pytree):
|
|
with self.assertRaises(TypeError):
|
|
pytree.treespec_dumps("random_blurb")
|
|
|
|
@parametrize_pytree_module
|
|
def test_is_namedtuple(self, pytree):
|
|
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
|
|
|
|
class DirectNamedTuple2(NamedTuple):
|
|
x: int
|
|
y: int
|
|
|
|
class IndirectNamedTuple1(DirectNamedTuple1):
|
|
pass
|
|
|
|
class IndirectNamedTuple2(DirectNamedTuple2):
|
|
pass
|
|
|
|
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1(0, 1)))
|
|
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2(0, 1)))
|
|
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1(0, 1)))
|
|
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2(0, 1)))
|
|
self.assertFalse(pytree.is_namedtuple(time.gmtime()))
|
|
self.assertFalse(pytree.is_namedtuple((0, 1)))
|
|
self.assertFalse(pytree.is_namedtuple([0, 1]))
|
|
self.assertFalse(pytree.is_namedtuple({0: 1, 1: 2}))
|
|
self.assertFalse(pytree.is_namedtuple({0, 1}))
|
|
self.assertFalse(pytree.is_namedtuple(1))
|
|
|
|
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1))
|
|
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2))
|
|
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1))
|
|
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2))
|
|
self.assertFalse(pytree.is_namedtuple(time.struct_time))
|
|
self.assertFalse(pytree.is_namedtuple(tuple))
|
|
self.assertFalse(pytree.is_namedtuple(list))
|
|
|
|
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple1))
|
|
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple2))
|
|
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple1))
|
|
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple2))
|
|
self.assertFalse(pytree.is_namedtuple_class(time.struct_time))
|
|
self.assertFalse(pytree.is_namedtuple_class(tuple))
|
|
self.assertFalse(pytree.is_namedtuple_class(list))
|
|
|
|
@parametrize_pytree_module
|
|
def test_is_structseq(self, pytree):
|
|
class FakeStructSeq(tuple):
|
|
n_fields = 2
|
|
n_sequence_fields = 2
|
|
n_unnamed_fields = 0
|
|
|
|
__slots__ = ()
|
|
__match_args__ = ("x", "y")
|
|
|
|
def __new__(cls, sequence):
|
|
return super().__new__(cls, sequence)
|
|
|
|
@property
|
|
def x(self):
|
|
return self[0]
|
|
|
|
@property
|
|
def y(self):
|
|
return self[1]
|
|
|
|
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
|
|
|
|
class DirectNamedTuple2(NamedTuple):
|
|
x: int
|
|
y: int
|
|
|
|
self.assertFalse(pytree.is_structseq(FakeStructSeq((0, 1))))
|
|
self.assertTrue(pytree.is_structseq(time.gmtime()))
|
|
self.assertFalse(pytree.is_structseq(DirectNamedTuple1(0, 1)))
|
|
self.assertFalse(pytree.is_structseq(DirectNamedTuple2(0, 1)))
|
|
self.assertFalse(pytree.is_structseq((0, 1)))
|
|
self.assertFalse(pytree.is_structseq([0, 1]))
|
|
self.assertFalse(pytree.is_structseq({0: 1, 1: 2}))
|
|
self.assertFalse(pytree.is_structseq({0, 1}))
|
|
self.assertFalse(pytree.is_structseq(1))
|
|
|
|
self.assertFalse(pytree.is_structseq(FakeStructSeq))
|
|
self.assertTrue(pytree.is_structseq(time.struct_time))
|
|
self.assertFalse(pytree.is_structseq(DirectNamedTuple1))
|
|
self.assertFalse(pytree.is_structseq(DirectNamedTuple2))
|
|
self.assertFalse(pytree.is_structseq(tuple))
|
|
self.assertFalse(pytree.is_structseq(list))
|
|
|
|
self.assertFalse(pytree.is_structseq_class(FakeStructSeq))
|
|
self.assertTrue(
|
|
pytree.is_structseq_class(time.struct_time),
|
|
)
|
|
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple1))
|
|
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple2))
|
|
self.assertFalse(pytree.is_structseq_class(tuple))
|
|
self.assertFalse(pytree.is_structseq_class(list))
|
|
|
|
# torch.return_types.* are all PyStructSequence types
|
|
for cls in vars(torch.return_types).values():
|
|
if isinstance(cls, type) and issubclass(cls, tuple):
|
|
self.assertTrue(pytree.is_structseq(cls))
|
|
self.assertTrue(pytree.is_structseq_class(cls))
|
|
self.assertFalse(pytree.is_namedtuple(cls))
|
|
self.assertFalse(pytree.is_namedtuple_class(cls))
|
|
|
|
inst = cls(range(cls.n_sequence_fields))
|
|
self.assertTrue(pytree.is_structseq(inst))
|
|
self.assertTrue(pytree.is_structseq(type(inst)))
|
|
self.assertFalse(pytree.is_structseq_class(inst))
|
|
self.assertTrue(pytree.is_structseq_class(type(inst)))
|
|
self.assertFalse(pytree.is_namedtuple(inst))
|
|
self.assertFalse(pytree.is_namedtuple_class(inst))
|
|
else:
|
|
self.assertFalse(pytree.is_structseq(cls))
|
|
self.assertFalse(pytree.is_structseq_class(cls))
|
|
self.assertFalse(pytree.is_namedtuple(cls))
|
|
self.assertFalse(pytree.is_namedtuple_class(cls))
|
|
|
|
@parametrize_pytree_module
|
|
def test_enum_treespec_roundtrip(self, pytree):
|
|
data = {TestEnum.A: 5}
|
|
spec = pytree.tree_structure(data)
|
|
|
|
serialized = pytree.treespec_dumps(spec)
|
|
deserialized_spec = pytree.treespec_loads(serialized)
|
|
self.assertEqual(spec, deserialized_spec)
|
|
|
|
|
|
class TestPythonPytree(TestCase):
|
|
def test_deprecated_register_pytree_node(self):
|
|
class DummyType:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
with self.assertWarnsRegex(
|
|
FutureWarning, "torch.utils._pytree._register_pytree_node"
|
|
):
|
|
python_pytree._register_pytree_node(
|
|
DummyType,
|
|
lambda dummy: ([dummy.x, dummy.y], None),
|
|
lambda xs, _: DummyType(*xs),
|
|
)
|
|
|
|
with self.assertWarnsRegex(UserWarning, "already registered"):
|
|
python_pytree._register_pytree_node(
|
|
DummyType,
|
|
lambda dummy: ([dummy.x, dummy.y], None),
|
|
lambda xs, _: DummyType(*xs),
|
|
)
|
|
|
|
def test_import_pytree_doesnt_import_optree(self):
|
|
# importing torch.utils._pytree shouldn't import optree.
|
|
# only importing torch.utils._cxx_pytree should.
|
|
script = """
|
|
import sys
|
|
import torch
|
|
import torch.utils._pytree
|
|
assert "torch.utils._pytree" in sys.modules
|
|
if "torch.utils._cxx_pytree" in sys.modules:
|
|
raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree")
|
|
if "optree" in sys.modules:
|
|
raise RuntimeError("importing torch.utils._pytree should not import optree")
|
|
"""
|
|
try:
|
|
subprocess.check_output(
|
|
[sys.executable, "-c", script],
|
|
stderr=subprocess.STDOUT,
|
|
# On Windows, opening the subprocess with the default CWD makes `import torch`
|
|
# fail, so just set CWD to this script's directory
|
|
cwd=os.path.dirname(os.path.realpath(__file__)),
|
|
)
|
|
except subprocess.CalledProcessError as e:
|
|
self.fail(
|
|
msg=(
|
|
"Subprocess exception while attempting to run test: "
|
|
+ e.output.decode("utf-8")
|
|
)
|
|
)
|
|
|
|
def test_treespec_equality(self):
|
|
self.assertEqual(
|
|
python_pytree.LeafSpec(),
|
|
python_pytree.LeafSpec(),
|
|
)
|
|
self.assertEqual(
|
|
python_pytree.TreeSpec(list, None, []),
|
|
python_pytree.TreeSpec(list, None, []),
|
|
)
|
|
self.assertEqual(
|
|
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
|
|
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
|
|
)
|
|
self.assertFalse(
|
|
python_pytree.TreeSpec(tuple, None, [])
|
|
== python_pytree.TreeSpec(list, None, []),
|
|
)
|
|
self.assertTrue(
|
|
python_pytree.TreeSpec(tuple, None, [])
|
|
!= python_pytree.TreeSpec(list, None, []),
|
|
)
|
|
|
|
def test_treespec_repr(self):
|
|
# Check that it looks sane
|
|
tree = (0, [0, 0, [0]])
|
|
spec = python_pytree.tree_structure(tree)
|
|
self.assertEqual(
|
|
repr(spec),
|
|
(
|
|
"TreeSpec(tuple, None, [*,\n"
|
|
" TreeSpec(list, None, [*,\n"
|
|
" *,\n"
|
|
" TreeSpec(list, None, [*])])])"
|
|
),
|
|
)
|
|
|
|
@parametrize(
|
|
"spec",
|
|
[
|
|
# python_pytree.tree_structure([])
|
|
python_pytree.TreeSpec(list, None, []),
|
|
# python_pytree.tree_structure(())
|
|
python_pytree.TreeSpec(tuple, None, []),
|
|
# python_pytree.tree_structure({})
|
|
python_pytree.TreeSpec(dict, [], []),
|
|
# python_pytree.tree_structure([0])
|
|
python_pytree.TreeSpec(list, None, [python_leafspec]),
|
|
# python_pytree.tree_structure([0, 1])
|
|
python_pytree.TreeSpec(
|
|
list,
|
|
None,
|
|
[python_leafspec, python_leafspec],
|
|
),
|
|
# python_pytree.tree_structure((0, 1, 2))
|
|
python_pytree.TreeSpec(
|
|
tuple,
|
|
None,
|
|
[python_leafspec, python_leafspec, python_leafspec],
|
|
),
|
|
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
|
|
python_pytree.TreeSpec(
|
|
dict,
|
|
["a", "b", "c"],
|
|
[python_leafspec, python_leafspec, python_leafspec],
|
|
),
|
|
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
|
|
python_pytree.TreeSpec(
|
|
OrderedDict,
|
|
["a", "b", "c"],
|
|
[
|
|
python_pytree.TreeSpec(
|
|
tuple,
|
|
None,
|
|
[python_leafspec, python_leafspec],
|
|
),
|
|
python_leafspec,
|
|
python_pytree.TreeSpec(
|
|
dict,
|
|
["a", "b", "c"],
|
|
[python_leafspec, python_leafspec, python_leafspec],
|
|
),
|
|
],
|
|
),
|
|
# python_pytree.tree_structure([(0, 1, [2, 3])])
|
|
python_pytree.TreeSpec(
|
|
list,
|
|
None,
|
|
[
|
|
python_pytree.TreeSpec(
|
|
tuple,
|
|
None,
|
|
[
|
|
python_leafspec,
|
|
python_leafspec,
|
|
python_pytree.TreeSpec(
|
|
list,
|
|
None,
|
|
[python_leafspec, python_leafspec],
|
|
),
|
|
],
|
|
),
|
|
],
|
|
),
|
|
# python_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}}))
|
|
python_pytree.TreeSpec(
|
|
defaultdict,
|
|
[list, ["a", "b", "c"]],
|
|
[
|
|
python_pytree.TreeSpec(
|
|
list,
|
|
None,
|
|
[python_leafspec, python_leafspec],
|
|
),
|
|
python_pytree.TreeSpec(
|
|
list,
|
|
None,
|
|
[python_leafspec, python_leafspec],
|
|
),
|
|
python_pytree.TreeSpec(dict, [], []),
|
|
],
|
|
),
|
|
],
|
|
)
|
|
def test_pytree_serialize(self, spec):
|
|
# Ensure that the spec is valid
|
|
self.assertEqual(
|
|
spec,
|
|
python_pytree.tree_structure(
|
|
python_pytree.tree_unflatten([0] * spec.num_leaves, spec)
|
|
),
|
|
)
|
|
|
|
serialized_spec = python_pytree.treespec_dumps(spec)
|
|
self.assertIsInstance(serialized_spec, str)
|
|
self.assertEqual(spec, python_pytree.treespec_loads(serialized_spec))
|
|
|
|
def test_pytree_serialize_defaultdict_enum(self):
|
|
spec = python_pytree.TreeSpec(
|
|
defaultdict,
|
|
[list, [TestEnum.A]],
|
|
[
|
|
python_pytree.TreeSpec(
|
|
list,
|
|
None,
|
|
[
|
|
python_leafspec,
|
|
],
|
|
),
|
|
],
|
|
)
|
|
serialized_spec = python_pytree.treespec_dumps(spec)
|
|
self.assertIsInstance(serialized_spec, str)
|
|
|
|
def test_pytree_serialize_enum(self):
|
|
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec])
|
|
|
|
serialized_spec = python_pytree.treespec_dumps(spec)
|
|
self.assertIsInstance(serialized_spec, str)
|
|
|
|
def test_pytree_serialize_namedtuple(self):
|
|
Point1 = namedtuple("Point1", ["x", "y"])
|
|
python_pytree._register_namedtuple(
|
|
Point1,
|
|
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
|
|
)
|
|
|
|
spec = python_pytree.tree_structure(Point1(1, 2))
|
|
self.assertIs(spec.type, namedtuple)
|
|
roundtrip_spec = python_pytree.treespec_loads(
|
|
python_pytree.treespec_dumps(spec)
|
|
)
|
|
self.assertEqual(spec, roundtrip_spec)
|
|
|
|
class Point2(NamedTuple):
|
|
x: int
|
|
y: int
|
|
|
|
python_pytree._register_namedtuple(
|
|
Point2,
|
|
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
|
|
)
|
|
|
|
spec = python_pytree.tree_structure(Point2(1, 2))
|
|
self.assertIs(spec.type, namedtuple)
|
|
roundtrip_spec = python_pytree.treespec_loads(
|
|
python_pytree.treespec_dumps(spec)
|
|
)
|
|
self.assertEqual(spec, roundtrip_spec)
|
|
|
|
class Point3(Point2):
|
|
pass
|
|
|
|
python_pytree._register_namedtuple(
|
|
Point3,
|
|
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point3",
|
|
)
|
|
|
|
spec = python_pytree.tree_structure(Point3(1, 2))
|
|
self.assertIs(spec.type, namedtuple)
|
|
roundtrip_spec = python_pytree.treespec_loads(
|
|
python_pytree.treespec_dumps(spec)
|
|
)
|
|
self.assertEqual(spec, roundtrip_spec)
|
|
|
|
def test_pytree_serialize_namedtuple_bad(self):
|
|
DummyType = namedtuple("DummyType", ["x", "y"])
|
|
|
|
spec = python_pytree.tree_structure(DummyType(1, 2))
|
|
|
|
with self.assertRaisesRegex(
|
|
NotImplementedError, "Please register using `_register_namedtuple`"
|
|
):
|
|
python_pytree.treespec_dumps(spec)
|
|
|
|
def test_pytree_custom_type_serialize_bad(self):
|
|
class DummyType:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
python_pytree.register_pytree_node(
|
|
DummyType,
|
|
lambda dummy: ([dummy.x, dummy.y], None),
|
|
lambda xs, _: DummyType(*xs),
|
|
)
|
|
|
|
spec = python_pytree.tree_structure(DummyType(1, 2))
|
|
with self.assertRaisesRegex(
|
|
NotImplementedError, "No registered serialization name"
|
|
):
|
|
python_pytree.treespec_dumps(spec)
|
|
|
|
def test_pytree_custom_type_serialize(self):
|
|
class DummyType:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
python_pytree.register_pytree_node(
|
|
DummyType,
|
|
lambda dummy: ([dummy.x, dummy.y], None),
|
|
lambda xs, _: DummyType(*xs),
|
|
serialized_type_name="test_pytree_custom_type_serialize.DummyType",
|
|
to_dumpable_context=lambda context: "moo",
|
|
from_dumpable_context=lambda dumpable_context: None,
|
|
)
|
|
spec = python_pytree.tree_structure(DummyType(1, 2))
|
|
serialized_spec = python_pytree.treespec_dumps(spec, 1)
|
|
self.assertIn("moo", serialized_spec)
|
|
roundtrip_spec = python_pytree.treespec_loads(serialized_spec)
|
|
self.assertEqual(roundtrip_spec, spec)
|
|
|
|
def test_pytree_serialize_register_bad(self):
|
|
class DummyType:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Both to_dumpable_context and from_dumpable_context"
|
|
):
|
|
python_pytree.register_pytree_node(
|
|
DummyType,
|
|
lambda dummy: ([dummy.x, dummy.y], None),
|
|
lambda xs, _: DummyType(*xs),
|
|
serialized_type_name="test_pytree_serialize_register_bad.DummyType",
|
|
to_dumpable_context=lambda context: "moo",
|
|
)
|
|
|
|
def test_pytree_context_serialize_bad(self):
|
|
class DummyType:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
python_pytree.register_pytree_node(
|
|
DummyType,
|
|
lambda dummy: ([dummy.x, dummy.y], None),
|
|
lambda xs, _: DummyType(*xs),
|
|
serialized_type_name="test_pytree_serialize_serialize_bad.DummyType",
|
|
to_dumpable_context=lambda context: DummyType,
|
|
from_dumpable_context=lambda dumpable_context: None,
|
|
)
|
|
|
|
spec = python_pytree.tree_structure(DummyType(1, 2))
|
|
|
|
with self.assertRaisesRegex(
|
|
TypeError, "Object of type type is not JSON serializable"
|
|
):
|
|
python_pytree.treespec_dumps(spec)
|
|
|
|
def test_pytree_serialize_bad_protocol(self):
|
|
import json
|
|
|
|
Point = namedtuple("Point", ["x", "y"])
|
|
spec = python_pytree.tree_structure(Point(1, 2))
|
|
python_pytree._register_namedtuple(
|
|
Point,
|
|
serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",
|
|
)
|
|
|
|
with self.assertRaisesRegex(ValueError, "Unknown protocol"):
|
|
python_pytree.treespec_dumps(spec, -1)
|
|
|
|
serialized_spec = python_pytree.treespec_dumps(spec)
|
|
_, data = json.loads(serialized_spec)
|
|
bad_protocol_serialized_spec = json.dumps((-1, data))
|
|
|
|
with self.assertRaisesRegex(ValueError, "Unknown protocol"):
|
|
python_pytree.treespec_loads(bad_protocol_serialized_spec)
|
|
|
|
def test_saved_serialized(self):
|
|
# python_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 3, 5: 4, 6: 5})]))
|
|
complicated_spec = python_pytree.TreeSpec(
|
|
OrderedDict,
|
|
[1, 2, 3],
|
|
[
|
|
python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]),
|
|
python_leafspec,
|
|
python_pytree.TreeSpec(
|
|
dict,
|
|
[4, 5, 6],
|
|
[python_leafspec, python_leafspec, python_leafspec],
|
|
),
|
|
],
|
|
)
|
|
# Ensure that the spec is valid
|
|
self.assertEqual(
|
|
complicated_spec,
|
|
python_pytree.tree_structure(
|
|
python_pytree.tree_unflatten(
|
|
[0] * complicated_spec.num_leaves, complicated_spec
|
|
)
|
|
),
|
|
)
|
|
|
|
serialized_spec = python_pytree.treespec_dumps(complicated_spec)
|
|
saved_spec = (
|
|
'[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", '
|
|
'"children_spec": [{"type": "builtins.tuple", "context": "null", '
|
|
'"children_spec": [{"type": null, "context": null, '
|
|
'"children_spec": []}, {"type": null, "context": null, '
|
|
'"children_spec": []}]}, {"type": null, "context": null, '
|
|
'"children_spec": []}, {"type": "builtins.dict", "context": '
|
|
'"[4, 5, 6]", "children_spec": [{"type": null, "context": null, '
|
|
'"children_spec": []}, {"type": null, "context": null, "children_spec": '
|
|
'[]}, {"type": null, "context": null, "children_spec": []}]}]}]'
|
|
)
|
|
self.assertEqual(serialized_spec, saved_spec)
|
|
self.assertEqual(complicated_spec, python_pytree.treespec_loads(saved_spec))
|
|
|
|
def test_tree_map_with_path(self):
|
|
tree = [{i: i for i in range(10)}]
|
|
all_zeros = python_pytree.tree_map_with_path(
|
|
lambda kp, val: val - kp[1].key + kp[0].idx, tree
|
|
)
|
|
self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)])
|
|
|
|
def test_dataclass(self):
|
|
@dataclass
|
|
class Data:
|
|
a: torch.Tensor
|
|
b: str = "moo"
|
|
c: Optional[str] = None
|
|
d: str = field(init=False, default="")
|
|
|
|
python_pytree.register_dataclass(Data)
|
|
old_data = Data(torch.tensor(3), "b", "c")
|
|
old_data.d = "d"
|
|
new_data = python_pytree.tree_map(lambda x: x, old_data)
|
|
self.assertEqual(new_data.a, torch.tensor(3))
|
|
self.assertEqual(new_data.b, "b")
|
|
self.assertEqual(new_data.c, "c")
|
|
self.assertEqual(new_data.d, "")
|
|
python_pytree._deregister_pytree_node(Data)
|
|
|
|
with self.assertRaisesRegex(ValueError, "Missing fields"):
|
|
python_pytree.register_dataclass(Data, field_names=["a", "b"])
|
|
|
|
with self.assertRaisesRegex(ValueError, "Unexpected fields"):
|
|
python_pytree.register_dataclass(Data, field_names=["a", "b", "e"])
|
|
|
|
with self.assertRaisesRegex(ValueError, "Unexpected fields"):
|
|
python_pytree.register_dataclass(Data, field_names=["a", "b", "c", "d"])
|
|
|
|
python_pytree.register_dataclass(
|
|
Data, field_names=["a"], drop_field_names=["b", "c"]
|
|
)
|
|
old_data = Data(torch.tensor(3), "b", "c")
|
|
new_data = python_pytree.tree_map(lambda x: x, old_data)
|
|
self.assertEqual(new_data.a, torch.tensor(3))
|
|
self.assertEqual(new_data.b, "moo")
|
|
self.assertEqual(new_data.c, None)
|
|
python_pytree._deregister_pytree_node(Data)
|
|
|
|
def test_register_dataclass_class(self):
|
|
class CustomClass:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
with self.assertRaisesRegex(ValueError, "field_names must be specified"):
|
|
python_pytree.register_dataclass(CustomClass)
|
|
|
|
python_pytree.register_dataclass(CustomClass, field_names=["x", "y"])
|
|
c = CustomClass(torch.tensor(0), torch.tensor(1))
|
|
mapped = python_pytree.tree_map(lambda x: x + 1, c)
|
|
self.assertEqual(mapped.x, torch.tensor(1))
|
|
self.assertEqual(mapped.y, torch.tensor(2))
|
|
|
|
def test_constant(self):
|
|
# Either use `frozen=True` or `unsafe_hash=True` so we have a
|
|
# non-default `__hash__`.
|
|
@dataclass(unsafe_hash=True)
|
|
class Config:
|
|
norm: str
|
|
|
|
python_pytree.register_constant(Config)
|
|
|
|
config = Config("l1")
|
|
elements, spec = python_pytree.tree_flatten(config)
|
|
self.assertEqual(elements, [])
|
|
self.assertEqual(spec.context.value, config)
|
|
|
|
def test_constant_default_eq_error(self):
|
|
class Config:
|
|
def __init__(self, norm: str):
|
|
self.norm = norm
|
|
|
|
try:
|
|
python_pytree.register_constant(Config)
|
|
self.assertFalse(True) # must raise error before this
|
|
except TypeError as e:
|
|
msg = "register_constant(cls) expects `cls` to have a non-default `__eq__` implementation."
|
|
self.assertIn(msg, str(e))
|
|
|
|
def test_constant_default_hash_error(self):
|
|
class Config:
|
|
def __init__(self, norm: str):
|
|
self.norm = norm
|
|
|
|
def __eq__(self, other):
|
|
return self.norm == other.norm
|
|
|
|
try:
|
|
python_pytree.register_constant(Config)
|
|
self.assertFalse(True) # must raise error before this
|
|
except TypeError as e:
|
|
msg = "register_constant(cls) expects `cls` to have a non-default `__hash__` implementation."
|
|
self.assertIn(msg, str(e))
|
|
|
|
def test_tree_map_with_path_multiple_trees(self):
|
|
@dataclass
|
|
class ACustomPytree:
|
|
x: Any
|
|
y: Any
|
|
z: Any
|
|
|
|
tree1 = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5]
|
|
tree2 = [ACustomPytree(x=2, y={"cin": [2, 2, 2], "bar": 2}, z="leaf"), 2]
|
|
|
|
python_pytree.register_pytree_node(
|
|
ACustomPytree,
|
|
flatten_fn=lambda f: ([f.x, f.y], f.z),
|
|
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
|
|
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
|
|
)
|
|
from_two_trees = python_pytree.tree_map_with_path(
|
|
lambda kp, a, b: a + b, tree1, tree2
|
|
)
|
|
from_one_tree = python_pytree.tree_map(lambda a: a + 2, tree1)
|
|
self.assertEqual(from_two_trees, from_one_tree)
|
|
|
|
def test_tree_flatten_with_path_is_leaf(self):
|
|
leaf_dict = {"foo": [(3)]}
|
|
tree = (["hello", [1, 2], leaf_dict],)
|
|
key_leaves, _ = python_pytree.tree_flatten_with_path(
|
|
tree, is_leaf=lambda x: isinstance(x, dict)
|
|
)
|
|
self.assertTrue(key_leaves[-1][1] is leaf_dict)
|
|
|
|
def test_tree_flatten_with_path_roundtrip(self):
|
|
class ANamedTuple(NamedTuple):
|
|
x: torch.Tensor
|
|
y: int
|
|
z: str
|
|
|
|
@dataclass
|
|
class ACustomPytree:
|
|
x: Any
|
|
y: Any
|
|
z: Any
|
|
|
|
python_pytree.register_pytree_node(
|
|
ACustomPytree,
|
|
flatten_fn=lambda f: ([f.x, f.y], f.z),
|
|
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
|
|
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
|
|
)
|
|
|
|
SOME_PYTREES = [
|
|
(None,),
|
|
["hello", [1, 2], {"foo": [(3)]}],
|
|
[ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
|
|
[ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
|
|
]
|
|
for tree in SOME_PYTREES:
|
|
key_leaves, spec = python_pytree.tree_flatten_with_path(tree)
|
|
actual = python_pytree.tree_unflatten(
|
|
[leaf for _, leaf in key_leaves], spec
|
|
)
|
|
self.assertEqual(actual, tree)
|
|
|
|
def test_tree_leaves_with_path(self):
|
|
class ANamedTuple(NamedTuple):
|
|
x: torch.Tensor
|
|
y: int
|
|
z: str
|
|
|
|
@dataclass
|
|
class ACustomPytree:
|
|
x: Any
|
|
y: Any
|
|
z: Any
|
|
|
|
python_pytree.register_pytree_node(
|
|
ACustomPytree,
|
|
flatten_fn=lambda f: ([f.x, f.y], f.z),
|
|
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
|
|
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
|
|
)
|
|
|
|
SOME_PYTREES = [
|
|
(None,),
|
|
["hello", [1, 2], {"foo": [(3)]}],
|
|
[ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
|
|
[ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
|
|
]
|
|
for tree in SOME_PYTREES:
|
|
flat_out, _ = python_pytree.tree_flatten_with_path(tree)
|
|
leaves_out = python_pytree.tree_leaves_with_path(tree)
|
|
self.assertEqual(flat_out, leaves_out)
|
|
|
|
def test_key_str(self):
|
|
class ANamedTuple(NamedTuple):
|
|
x: str
|
|
y: int
|
|
|
|
tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
|
|
flat, _ = python_pytree.tree_flatten_with_path(tree)
|
|
paths = [f"{python_pytree.keystr(kp)}: {val}" for kp, val in flat]
|
|
self.assertEqual(
|
|
paths,
|
|
[
|
|
"[0][0]: hello",
|
|
"[0][1][0]: 1",
|
|
"[0][1][1]: 2",
|
|
"[0][2]['foo'][0]: 3",
|
|
"[0][2]['bar'][0].x: baz",
|
|
"[0][2]['bar'][0].y: 10",
|
|
],
|
|
)
|
|
|
|
def test_flatten_flatten_with_key_consistency(self):
|
|
"""Check that flatten and flatten_with_key produces consistent leaves/context."""
|
|
reg = python_pytree.SUPPORTED_NODES
|
|
|
|
EXAMPLE_TREE = {
|
|
list: [1, 2, 3],
|
|
tuple: (1, 2, 3),
|
|
dict: {"foo": 1, "bar": 2},
|
|
namedtuple: namedtuple("ANamedTuple", ["x", "y"])(1, 2),
|
|
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
|
|
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
|
|
deque: deque([1, 2, 3]),
|
|
torch.Size: torch.Size([1, 2, 3]),
|
|
immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
|
|
immutable_list: immutable_list([1, 2, 3]),
|
|
}
|
|
|
|
for typ in reg:
|
|
example = EXAMPLE_TREE.get(typ)
|
|
if example is None:
|
|
continue
|
|
flat_with_path, spec1 = python_pytree.tree_flatten_with_path(example)
|
|
flat, spec2 = python_pytree.tree_flatten(example)
|
|
|
|
self.assertEqual(flat, [x[1] for x in flat_with_path])
|
|
self.assertEqual(spec1, spec2)
|
|
|
|
def test_key_access(self):
|
|
class ANamedTuple(NamedTuple):
|
|
x: str
|
|
y: int
|
|
|
|
tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
|
|
flat, _ = python_pytree.tree_flatten_with_path(tree)
|
|
for kp, val in flat:
|
|
self.assertEqual(python_pytree.key_get(tree, kp), val)
|
|
|
|
|
|
class TestCxxPytree(TestCase):
|
|
def setUp(self):
|
|
if IS_FBCODE:
|
|
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
|
|
|
|
def test_treespec_equality(self):
|
|
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
|
|
|
|
def test_treespec_repr(self):
|
|
# Check that it looks sane
|
|
tree = (0, [0, 0, [0]])
|
|
spec = cxx_pytree.tree_structure(tree)
|
|
self.assertEqual(
|
|
repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')"
|
|
)
|
|
|
|
@parametrize(
|
|
"spec",
|
|
[
|
|
cxx_pytree.tree_structure([]),
|
|
cxx_pytree.tree_structure(()),
|
|
cxx_pytree.tree_structure({}),
|
|
cxx_pytree.tree_structure([0]),
|
|
cxx_pytree.tree_structure([0, 1]),
|
|
cxx_pytree.tree_structure((0, 1, 2)),
|
|
cxx_pytree.tree_structure({"a": 0, "b": 1, "c": 2}),
|
|
cxx_pytree.tree_structure(
|
|
OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
|
|
),
|
|
cxx_pytree.tree_structure([(0, 1, [2, 3])]),
|
|
cxx_pytree.tree_structure(
|
|
defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})
|
|
),
|
|
],
|
|
)
|
|
def test_pytree_serialize(self, spec):
|
|
self.assertEqual(
|
|
spec,
|
|
cxx_pytree.tree_structure(
|
|
cxx_pytree.tree_unflatten([0] * spec.num_leaves, spec)
|
|
),
|
|
)
|
|
|
|
serialized_spec = cxx_pytree.treespec_dumps(spec)
|
|
self.assertIsInstance(serialized_spec, str)
|
|
self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))
|
|
|
|
def test_pytree_serialize_namedtuple(self):
|
|
python_pytree._register_namedtuple(
|
|
GlobalPoint,
|
|
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.GlobalPoint",
|
|
)
|
|
spec = cxx_pytree.tree_structure(GlobalPoint(0, 1))
|
|
|
|
roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
|
|
self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)
|
|
|
|
LocalPoint = namedtuple("LocalPoint", ["x", "y"])
|
|
python_pytree._register_namedtuple(
|
|
LocalPoint,
|
|
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.LocalPoint",
|
|
)
|
|
spec = cxx_pytree.tree_structure(LocalPoint(0, 1))
|
|
|
|
roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
|
|
self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)
|
|
|
|
def test_pytree_custom_type_serialize(self):
|
|
spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1))
|
|
serialized_spec = cxx_pytree.treespec_dumps(spec)
|
|
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
|
|
self.assertEqual(roundtrip_spec, spec)
|
|
|
|
class LocalDummyType:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
cxx_pytree.register_pytree_node(
|
|
LocalDummyType,
|
|
lambda dummy: ([dummy.x, dummy.y], None),
|
|
lambda xs, _: LocalDummyType(*xs),
|
|
serialized_type_name="LocalDummyType",
|
|
)
|
|
spec = cxx_pytree.tree_structure(LocalDummyType(0, 1))
|
|
serialized_spec = cxx_pytree.treespec_dumps(spec)
|
|
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
|
|
self.assertEqual(roundtrip_spec, spec)
|
|
|
|
|
|
instantiate_parametrized_tests(TestGenericPytree)
|
|
instantiate_parametrized_tests(TestPythonPytree)
|
|
instantiate_parametrized_tests(TestCxxPytree)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|