diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index cd7fe2e88349..1a9d8e8155e4 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -87,12 +87,15 @@ from torch.testing._internal.common_methods_invocations import ( ) from torch.testing._internal.common_utils import ( freeze_rng_state, + instantiate_parametrized_tests, IS_FBCODE, + parametrize, scoped_load_inline, set_default_dtype, skipIfHpu, skipIfNNModuleInlined, skipIfWindows, + subtest, TEST_HPU, TEST_XPU, wrapDeterministicFlagAPITest, @@ -101,11 +104,21 @@ from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.logging_utils import logs_to_string +pytree_modules = { + "python": python_pytree, +} if python_pytree._cxx_pytree_dynamo_traceable: import torch.utils._cxx_pytree as cxx_pytree + + pytree_modules["cxx"] = cxx_pytree else: cxx_pytree = None +parametrize_pytree_module = parametrize( + "pytree", + [subtest(module, name=name) for name, module in pytree_modules.items()], +) + MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"]) T = typing.TypeVar("T") @@ -9107,71 +9120,6 @@ def ___make_guard_fn(): opt = torch.compile(fn, backend="eager") opt() - def test_tracing_py_tree(self): - def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) - res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) - - xs = [torch.tensor(i) for i in range(3)] - - counter = CompileCounter() - torch.compile(fn, backend=counter, fullgraph=True)(xs) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 3) - - def test_tracing_nested_py_tree(self): - def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) - res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) - - xs = [torch.tensor(i) for i in range(3)] - xsl = [xs, xs, xs, xs] - - counter = CompileCounter() - comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) - real_out = fn(xsl) - self.assertEqual(comp_out, real_out) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 12) - - def test_tracing_nested_py_tree_tuples(self): - def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) - res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) - - xs = [torch.tensor(i) for i in range(3)] - xsl = (xs, xs, xs, xs) - - counter = CompileCounter() - comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) - real_out = fn(xsl) - self.assertEqual(comp_out, real_out) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 12) - - def test_tracing_nested_py_tree_dicts(self): - def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) - res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) - - xs = [torch.tensor(i) for i in range(3)] - xsl = { - "a": xs, - "b": xs, - "c": xs, - } - - counter = CompileCounter() - comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) - real_out = fn(xsl) - self.assertEqual(comp_out, real_out) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 9) - def test_dynamic_one_hot(self): def fn(x): x = x + 1 @@ -9188,28 +9136,6 @@ def ___make_guard_fn(): self.assertEqual(counter.frame_count, 2) self.assertEqual(counter.op_count, 2) - def test_tracing_nested_py_tree_mixed_all(self): - def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) - res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) - - xs = [torch.tensor(i) for i in range(3)] - xsa = (xs, xs) - xsb = {"aa": xsa, "ab": xs} - xsl = { - "a": xs, - "b": xsa, - "c": xsb, - } - - counter = CompileCounter() - comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) - real_out = fn(xsl) - self.assertEqual(comp_out, real_out) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 18) - def test_any_all_symnode(self): cnt = CompileCounter() @@ -9236,46 +9162,6 @@ def ___make_guard_fn(): self.assertEqual(fn(y3), y3 - 3) self.assertEqual(cnt.frame_count, 2) - def test_tracing_py_tree_tensor_subclass(self): - from torch.testing._internal.two_tensor import TwoTensor - from torch.utils.checkpoint import checkpoint - - def fn(xs): - nested_xs = [[xs]] - flat_xs, spec = python_pytree.tree_flatten(xs) - return flat_xs[0].clone() - - # use checkpoint to trigger a "sourceless" tensor subclass - def checkpoint_fn(xs): - return checkpoint(fn, xs, use_reentrant=True) - - xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2)) - - counter = CompileCounter() - torch.compile(checkpoint_fn, backend=counter, fullgraph=True)(xs) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 2) - - def test_tracing_tree_map_only(self): - def fn(xs): - def mapper(x): - return x.clone() - - y = python_pytree.tree_map_only(torch.Tensor, mapper, xs) - return y - - xs = [torch.tensor(i) for i in range(3)] + ["hi"] - xsa = (xs, xs) - xsb = {"aa": xsa, "ab": xs} - - counter = CompileCounter() - comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsb) - real_out = fn(xsb) - - self.assertEqual(comp_out, real_out) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 9) - @torch._dynamo.config.patch( capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True ) @@ -10718,139 +10604,6 @@ def ___make_guard_fn(): expected = fn(*inps) self.assertEqual(actual, expected) - def test_pytree_tree_leaves(self): - implementations = [("python", python_pytree)] - if cxx_pytree is not None: - implementations.append(("cxx", cxx_pytree)) - - for name, module in implementations: - with self.subTest(f"pytree implement: {name}"): - - def fn(x): - tree = { - "a": [x, x - 1], - "b": x + 2, - "c": ( - x, - 3.0, - collections.deque([0.0, -x, 1, 2], maxlen=3), - ), - "d": collections.OrderedDict( - { - "e": torch.return_types.qr((2 * x, None)), - "f": MyTuple(x, x + 1, torch.zeros(4, 3)), - }, - ), - } - leaves = module.tree_leaves(tree) - return leaves - - x = torch.randn(3, 2) - expected = fn(x) - fn_opt = torch.compile(fullgraph=True)(fn) - actual = fn_opt(x) - - self.assertEqual(actual, expected) - - def test_pytree_tree_flatten_unflatten(self): - implementations = [("python", python_pytree)] - if cxx_pytree is not None: - implementations.append(("cxx", cxx_pytree)) - - for name, module in implementations: - with self.subTest(f"pytree implement: {name}"): - - def fn(x, y): - tree = { - "a": [x, x - 1], - "b": x + 2, - "c": ( - x, - 3.0, - collections.deque([0.0, -x, 1, 2], maxlen=3), - ), - "d": collections.OrderedDict( - { - "e": torch.return_types.qr((2 * x, None)), - "f": MyTuple(x, x + 1, torch.zeros(4, 3)), - }, - ), - } - leaves, treespec = module.tree_flatten(tree) - new_leaves = [ - x - 1, - y, - x * y, - 3.0, - y - 2, - 1, - torch.zeros(2, 2), - 2 * y, - -y, - x + y, - x - y, - torch.ones(3, 2), - 1, - ] - new_tree = module.tree_unflatten(new_leaves, treespec) - return leaves, new_tree - - x = torch.randn(3, 2) - y = torch.randn(3, 2) - expected = fn(x, y) - fn_opt = torch.compile(fullgraph=True)(fn) - actual = fn_opt(x, y) - - self.assertEqual(actual, expected) - - def test_pytree_tree_map(self): - implementations = [("python", python_pytree)] - if cxx_pytree is not None: - implementations.append(("cxx", cxx_pytree)) - - for name, module in implementations: - with self.subTest(f"pytree implement: {name}"): - - def fn(x, y): - tree1 = { - "a": [x, x - 1], - "b": x + 2, - "c": ( - x, - 3.0, - collections.deque([0.0, -x, 1, 2], maxlen=3), - ), - "d": collections.OrderedDict( - { - "e": torch.return_types.qr((2 * x, None)), - "f": MyTuple(x, x + 1, torch.zeros(4, 3)), - }, - ), - } - tree2 = collections.OrderedDict( - [ - ("c", (y, 3.0, collections.deque([1, -y, 10.0]))), - ("a", [y, y + 1]), - ("b", y + 2), - ( - "d", - { - "f": MyTuple(torch.ones(4, 3), -y, y + 1), - "e": torch.return_types.qr((2 * y, None)), - }, - ), - ], - ) - return module.tree_map(lambda u, v: (u, v), tree1, tree2) - - x = torch.randn(3, 2) - y = torch.randn(3, 2) - expected = fn(x, y) - fn_opt = torch.compile(fullgraph=True)(fn) - actual = fn_opt(x, y) - - self.assertEqual(actual, expected) - def test_shape_env_no_recording(self): main = ShapeEnv(should_record_events=False) @@ -12886,6 +12639,257 @@ fn self.assertRaises(Unsupported, f, "1 + j") +class MiscTestsPyTree(torch._inductor.test_case.TestCase): + @parametrize_pytree_module + def test_tracing_pytree(self, pytree): + def fn(xs): + flat_xs, spec = pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + + counter = CompileCounter() + torch.compile(fn, backend=counter, fullgraph=True)(xs) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 3) + + @parametrize_pytree_module + def test_tracing_nested_pytree(self, pytree): + def fn(xs): + flat_xs, spec = pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsl = [xs, xs, xs, xs] + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 12) + + @parametrize_pytree_module + def test_tracing_nested_tuples(self, pytree): + def fn(xs): + flat_xs, spec = pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsl = (xs, xs, xs, xs) + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 12) + + @parametrize_pytree_module + def test_tracing_nested_dicts(self, pytree): + def fn(xs): + flat_xs, spec = pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsl = { + "a": xs, + "b": xs, + "c": xs, + } + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 9) + + @parametrize_pytree_module + def test_tracing_nested_mixed_all(self, pytree): + def fn(xs): + flat_xs, spec = pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsa = (xs, xs) + xsb = {"aa": xsa, "ab": xs} + xsl = { + "a": xs, + "b": xsa, + "c": xsb, + } + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 18) + + @parametrize_pytree_module + def test_tracing_nested_tensor_subclass(self, pytree): + from torch.testing._internal.two_tensor import TwoTensor + from torch.utils.checkpoint import checkpoint + + def fn(xs): + nested_xs = [[xs]] + flat_xs, spec = pytree.tree_flatten(xs) + return flat_xs[0].clone() + + # use checkpoint to trigger a "sourceless" tensor subclass + def checkpoint_fn(xs): + return checkpoint(fn, xs, use_reentrant=True) + + xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2)) + + counter = CompileCounter() + torch.compile(checkpoint_fn, backend=counter, fullgraph=True)(xs) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 2) + + @parametrize_pytree_module + def test_pytree_tree_leaves(self, pytree): + def fn(x): + tree = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + collections.deque([0.0, -x, 1, 2], maxlen=3), + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, x + 1, torch.zeros(4, 3)), + }, + ), + } + leaves = pytree.tree_leaves(tree) + return leaves + + x = torch.randn(3, 2) + expected = fn(x) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x) + + self.assertEqual(actual, expected) + + @parametrize_pytree_module + def test_pytree_tree_flatten_unflatten(self, pytree): + def fn(x, y): + tree = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + collections.deque([0.0, -x, 1, 2], maxlen=3), + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, x + 1, torch.zeros(4, 3)), + }, + ), + } + leaves, treespec = pytree.tree_flatten(tree) + new_leaves = [ + x - 1, + y, + x * y, + 3.0, + y - 2, + 1, + torch.zeros(2, 2), + 2 * y, + -y, + x + y, + x - y, + torch.ones(3, 2), + 1, + ] + new_tree = pytree.tree_unflatten(new_leaves, treespec) + return leaves, new_tree + + x = torch.randn(3, 2) + y = torch.randn(3, 2) + expected = fn(x, y) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x, y) + + self.assertEqual(actual, expected) + + @parametrize_pytree_module + def test_pytree_tree_map(self, pytree): + def fn(x, y): + tree1 = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + collections.deque([0.0, -x, 1, 2], maxlen=3), + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, x + 1, torch.zeros(4, 3)), + }, + ), + } + tree2 = collections.OrderedDict( + [ + ("c", (y, 3.0, collections.deque([1, -y, 10.0]))), + ("a", [y, y + 1]), + ("b", y + 2), + ( + "d", + { + "f": MyTuple(torch.ones(4, 3), -y, y + 1), + "e": torch.return_types.qr((2 * y, None)), + }, + ), + ], + ) + return pytree.tree_map(lambda u, v: (u, v), tree1, tree2) + + x = torch.randn(3, 2) + y = torch.randn(3, 2) + expected = fn(x, y) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x, y) + + self.assertEqual(actual, expected) + + @parametrize_pytree_module + def test_pytree_tree_map_only(self, pytree): + def fn(xs): + def mapper(x): + return x.clone() + + y = pytree.tree_map_only(torch.Tensor, mapper, xs) + return y + + xs = [torch.tensor(i) for i in range(3)] + ["hi"] + xsa = (xs, xs) + xsb = {"aa": xsa, "ab": xs} + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsb) + real_out = fn(xsb) + + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 9) + + class TestTracer(JitTestCase): def test_jit_save(self): def fn(): @@ -13266,10 +13270,14 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase): # RuntimeError: value cannot be converted to type at::Half without overflow +instantiate_parametrized_tests(MiscTestsPyTree) + devices = ("cuda", "hpu", "xpu") instantiate_device_type_tests( MiscTestsDevice, globals(), only_for=devices, allow_xpu=True ) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/test_pytree.py b/test/test_pytree.py index 228dec85bff6..e19f1471267c 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -14,7 +14,7 @@ from enum import auto from typing import Any, NamedTuple, Optional import torch -import torch.utils._pytree as py_pytree +import torch.utils._pytree as python_pytree from torch.fx.immutable_collections import immutable_dict, immutable_list from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -26,12 +26,24 @@ from torch.testing._internal.common_utils import ( ) -if IS_FBCODE: - # optree is not yet enabled in fbcode, so just re-test the python implementation - cxx_pytree = py_pytree -else: +pytree_modules = { + "python": python_pytree, +} +if not IS_FBCODE: import torch.utils._cxx_pytree as cxx_pytree + pytree_modules["cxx"] = cxx_pytree +else: + # optree is not yet enabled in fbcode, so just re-test the python implementation + cxx_pytree = python_pytree + + +parametrize_pytree_module = parametrize( + "pytree", + [subtest(module, name=name) for name, module in pytree_modules.items()], +) + + GlobalPoint = namedtuple("GlobalPoint", ["x", "y"]) @@ -53,26 +65,32 @@ class TestEnum(enum.Enum): A = auto() +python_leafspec = python_pytree.LeafSpec() + + class TestGenericPytree(TestCase): def test_aligned_public_apis(self): - public_apis = py_pytree.__all__ + public_apis = python_pytree.__all__ self.assertEqual(public_apis, cxx_pytree.__all__) for name in public_apis: cxx_api = getattr(cxx_pytree, name) - py_api = getattr(py_pytree, name) + python_api = getattr(python_pytree, name) - self.assertEqual(inspect.isclass(cxx_api), inspect.isclass(py_api)) - self.assertEqual(inspect.isfunction(cxx_api), inspect.isfunction(py_api)) + self.assertEqual(inspect.isclass(cxx_api), inspect.isclass(python_api)) + self.assertEqual( + inspect.isfunction(cxx_api), + inspect.isfunction(python_api), + ) if inspect.isfunction(cxx_api): cxx_signature = inspect.signature(cxx_api) - py_signature = inspect.signature(py_api) + python_signature = inspect.signature(python_api) # Check the parameter names are the same. cxx_param_names = list(cxx_signature.parameters) - py_param_names = list(py_signature.parameters) - self.assertEqual(cxx_param_names, py_param_names) + python_param_names = list(python_signature.parameters) + self.assertEqual(cxx_param_names, python_param_names) # Check the positional parameters are the same. cxx_positional_param_names = [ @@ -86,9 +104,9 @@ class TestGenericPytree(TestCase): } ) ] - py_positional_param_names = [ + python_positional_param_names = [ n - for n, p in py_signature.parameters.items() + for n, p in python_signature.parameters.items() if ( p.kind in { @@ -97,19 +115,22 @@ class TestGenericPytree(TestCase): } ) ] - self.assertEqual(cxx_positional_param_names, py_positional_param_names) + self.assertEqual( + cxx_positional_param_names, + python_positional_param_names, + ) - for py_name, py_param in py_signature.parameters.items(): - self.assertIn(py_name, cxx_signature.parameters) - cxx_param = cxx_signature.parameters[py_name] + for python_name, python_param in python_signature.parameters.items(): + self.assertIn(python_name, cxx_signature.parameters) + cxx_param = cxx_signature.parameters[python_name] # Check parameter kinds and default values are the same. - self.assertEqual(cxx_param.kind, py_param.kind) - self.assertEqual(cxx_param.default, py_param.default) + self.assertEqual(cxx_param.kind, python_param.kind) + self.assertEqual(cxx_param.default, python_param.default) # Check parameter annotations are the same. if "TreeSpec" in str(cxx_param.annotation): - self.assertIn("TreeSpec", str(py_param.annotation)) + self.assertIn("TreeSpec", str(python_param.annotation)) self.assertEqual( re.sub( r"(?:\b)([\w\.]*)TreeSpec(?:\b)", @@ -119,78 +140,66 @@ class TestGenericPytree(TestCase): re.sub( r"(?:\b)([\w\.]*)TreeSpec(?:\b)", "TreeSpec", - str(py_param.annotation), + str(python_param.annotation), ), msg=( f"C++ parameter {cxx_param} " - f"does not match Python parameter {py_param} " + f"does not match Python parameter {python_param} " f"for API `{name}`" ), ) else: self.assertEqual( cxx_param.annotation, - py_param.annotation, + python_param.annotation, msg=( f"C++ parameter {cxx_param} " - f"does not match Python parameter {py_param} " + f"does not match Python parameter {python_param} " f"for API `{name}`" ), ) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_register_pytree_node(self, pytree_impl): + @parametrize_pytree_module + def test_register_pytree_node(self, pytree): class MyDict(UserDict): pass d = MyDict(a=1, b=2, c=3) # Custom types are leaf nodes by default - values, spec = pytree_impl.tree_flatten(d) + values, spec = pytree.tree_flatten(d) self.assertEqual(values, [d]) self.assertIs(values[0], d) - self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) + self.assertEqual(d, pytree.tree_unflatten(values, spec)) self.assertTrue(spec.is_leaf()) # Register MyDict as a pytree node - pytree_impl.register_pytree_node( + pytree.register_pytree_node( MyDict, lambda d: (list(d.values()), list(d.keys())), lambda values, keys: MyDict(zip(keys, values)), ) - values, spec = pytree_impl.tree_flatten(d) + values, spec = pytree.tree_flatten(d) self.assertEqual(values, [1, 2, 3]) - self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) + self.assertEqual(d, pytree.tree_unflatten(values, spec)) # Do not allow registering the same type twice with self.assertRaisesRegex(ValueError, "already registered"): - pytree_impl.register_pytree_node( + pytree.register_pytree_node( MyDict, lambda d: (list(d.values()), list(d.keys())), lambda values, keys: MyDict(zip(keys, values)), ) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_flatten_unflatten_leaf(self, pytree_impl): + @parametrize_pytree_module + def test_flatten_unflatten_leaf(self, pytree): def run_test_with_leaf(leaf): - values, treespec = pytree_impl.tree_flatten(leaf) + values, treespec = pytree.tree_flatten(leaf) self.assertEqual(values, [leaf]) - self.assertEqual(treespec, pytree_impl.LeafSpec()) + self.assertEqual(treespec, pytree.LeafSpec()) - unflattened = pytree_impl.tree_unflatten(values, treespec) + unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, leaf) run_test_with_leaf(1) @@ -200,16 +209,16 @@ class TestGenericPytree(TestCase): run_test_with_leaf(torch.randn(3, 3)) @parametrize( - "pytree_impl,gen_expected_fn", + "pytree,gen_expected_fn", [ subtest( ( - py_pytree, - lambda tup: py_pytree.TreeSpec( - tuple, None, [py_pytree.LeafSpec() for _ in tup] + python_pytree, + lambda tup: python_pytree.TreeSpec( + tuple, None, [python_leafspec for _ in tup] ), ), - name="py", + name="python", ), subtest( (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))), @@ -217,15 +226,15 @@ class TestGenericPytree(TestCase): ), ], ) - def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn): + def test_flatten_unflatten_tuple(self, pytree, gen_expected_fn): def run_test(tup): expected_spec = gen_expected_fn(tup) - values, treespec = pytree_impl.tree_flatten(tup) + values, treespec = pytree.tree_flatten(tup) self.assertIsInstance(values, list) self.assertEqual(values, list(tup)) self.assertEqual(treespec, expected_spec) - unflattened = pytree_impl.tree_unflatten(values, treespec) + unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, tup) self.assertIsInstance(unflattened, tuple) @@ -235,16 +244,16 @@ class TestGenericPytree(TestCase): run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11)) @parametrize( - "pytree_impl,gen_expected_fn", + "pytree,gen_expected_fn", [ subtest( ( - py_pytree, - lambda lst: py_pytree.TreeSpec( - list, None, [py_pytree.LeafSpec() for _ in lst] + python_pytree, + lambda lst: python_pytree.TreeSpec( + list, None, [python_leafspec for _ in lst] ), ), - name="py", + name="python", ), subtest( (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))), @@ -252,15 +261,15 @@ class TestGenericPytree(TestCase): ), ], ) - def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn): + def test_flatten_unflatten_list(self, pytree, gen_expected_fn): def run_test(lst): expected_spec = gen_expected_fn(lst) - values, treespec = pytree_impl.tree_flatten(lst) + values, treespec = pytree.tree_flatten(lst) self.assertIsInstance(values, list) self.assertEqual(values, lst) self.assertEqual(treespec, expected_spec) - unflattened = pytree_impl.tree_unflatten(values, treespec) + unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, lst) self.assertIsInstance(unflattened, list) @@ -269,18 +278,18 @@ class TestGenericPytree(TestCase): run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11]) @parametrize( - "pytree_impl,gen_expected_fn", + "pytree,gen_expected_fn", [ subtest( ( - py_pytree, - lambda dct: py_pytree.TreeSpec( + python_pytree, + lambda dct: python_pytree.TreeSpec( dict, list(dct.keys()), - [py_pytree.LeafSpec() for _ in dct.values()], + [python_leafspec for _ in dct.values()], ), ), - name="py", + name="python", ), subtest( ( @@ -291,15 +300,15 @@ class TestGenericPytree(TestCase): ), ], ) - def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn): + def test_flatten_unflatten_dict(self, pytree, gen_expected_fn): def run_test(dct): expected_spec = gen_expected_fn(dct) - values, treespec = pytree_impl.tree_flatten(dct) + values, treespec = pytree.tree_flatten(dct) self.assertIsInstance(values, list) self.assertEqual(values, list(dct.values())) self.assertEqual(treespec, expected_spec) - unflattened = pytree_impl.tree_unflatten(values, treespec) + unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, dct) self.assertIsInstance(unflattened, dict) @@ -310,18 +319,18 @@ class TestGenericPytree(TestCase): run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)}) @parametrize( - "pytree_impl,gen_expected_fn", + "pytree,gen_expected_fn", [ subtest( ( - py_pytree, - lambda odict: py_pytree.TreeSpec( + python_pytree, + lambda odict: python_pytree.TreeSpec( OrderedDict, list(odict.keys()), - [py_pytree.LeafSpec() for _ in odict.values()], + [python_leafspec for _ in odict.values()], ), ), - name="py", + name="python", ), subtest( ( @@ -334,15 +343,15 @@ class TestGenericPytree(TestCase): ), ], ) - def test_flatten_unflatten_ordereddict(self, pytree_impl, gen_expected_fn): + def test_flatten_unflatten_ordereddict(self, pytree, gen_expected_fn): def run_test(odict): expected_spec = gen_expected_fn(odict) - values, treespec = pytree_impl.tree_flatten(odict) + values, treespec = pytree.tree_flatten(odict) self.assertIsInstance(values, list) self.assertEqual(values, list(odict.values())) self.assertEqual(treespec, expected_spec) - unflattened = pytree_impl.tree_unflatten(values, treespec) + unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, odict) self.assertIsInstance(unflattened, OrderedDict) @@ -354,18 +363,18 @@ class TestGenericPytree(TestCase): run_test(od) @parametrize( - "pytree_impl,gen_expected_fn", + "pytree,gen_expected_fn", [ subtest( ( - py_pytree, - lambda ddct: py_pytree.TreeSpec( + python_pytree, + lambda ddct: python_pytree.TreeSpec( defaultdict, [ddct.default_factory, list(ddct.keys())], - [py_pytree.LeafSpec() for _ in ddct.values()], + [python_leafspec for _ in ddct.values()], ), ), - name="py", + name="python", ), subtest( ( @@ -378,15 +387,15 @@ class TestGenericPytree(TestCase): ), ], ) - def test_flatten_unflatten_defaultdict(self, pytree_impl, gen_expected_fn): + def test_flatten_unflatten_defaultdict(self, pytree, gen_expected_fn): def run_test(ddct): expected_spec = gen_expected_fn(ddct) - values, treespec = pytree_impl.tree_flatten(ddct) + values, treespec = pytree.tree_flatten(ddct) self.assertIsInstance(values, list) self.assertEqual(values, list(ddct.values())) self.assertEqual(treespec, expected_spec) - unflattened = pytree_impl.tree_unflatten(values, treespec) + unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, ddct) self.assertEqual(unflattened.default_factory, ddct.default_factory) self.assertIsInstance(unflattened, defaultdict) @@ -398,18 +407,16 @@ class TestGenericPytree(TestCase): run_test(defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)})) @parametrize( - "pytree_impl,gen_expected_fn", + "pytree,gen_expected_fn", [ subtest( ( - py_pytree, - lambda deq: py_pytree.TreeSpec( - deque, - deq.maxlen, - [py_pytree.LeafSpec() for _ in deq], + python_pytree, + lambda deq: python_pytree.TreeSpec( + deque, deq.maxlen, [python_leafspec for _ in deq] ), ), - name="py", + name="python", ), subtest( ( @@ -422,15 +429,15 @@ class TestGenericPytree(TestCase): ), ], ) - def test_flatten_unflatten_deque(self, pytree_impl, gen_expected_fn): + def test_flatten_unflatten_deque(self, pytree, gen_expected_fn): def run_test(deq): expected_spec = gen_expected_fn(deq) - values, treespec = pytree_impl.tree_flatten(deq) + values, treespec = pytree.tree_flatten(deq) self.assertIsInstance(values, list) self.assertEqual(values, list(deq)) self.assertEqual(treespec, expected_spec) - unflattened = pytree_impl.tree_unflatten(values, treespec) + unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, deq) self.assertEqual(unflattened.maxlen, deq.maxlen) self.assertIsInstance(unflattened, deque) @@ -439,29 +446,23 @@ class TestGenericPytree(TestCase): run_test(deque([1.0, 2])) run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8)) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_flatten_unflatten_namedtuple(self, pytree_impl): + @parametrize_pytree_module + def test_flatten_unflatten_namedtuple(self, pytree): Point = namedtuple("Point", ["x", "y"]) def run_test(tup): - if pytree_impl is py_pytree: - expected_spec = py_pytree.TreeSpec( - namedtuple, Point, [py_pytree.LeafSpec() for _ in tup] + if pytree is python_pytree: + expected_spec = python_pytree.TreeSpec( + namedtuple, Point, [python_leafspec for _ in tup] ) else: expected_spec = cxx_pytree.tree_structure(Point(0, 1)) - values, treespec = pytree_impl.tree_flatten(tup) + values, treespec = pytree.tree_flatten(tup) self.assertIsInstance(values, list) self.assertEqual(values, list(tup)) self.assertEqual(treespec, expected_spec) - unflattened = pytree_impl.tree_unflatten(values, treespec) + unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, tup) self.assertIsInstance(unflattened, Point) @@ -475,43 +476,31 @@ class TestGenericPytree(TestCase): subtest(torch.min, name="min"), ], ) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_flatten_unflatten_return_types(self, pytree_impl, op): + @parametrize_pytree_module + def test_flatten_unflatten_return_types(self, pytree, op): x = torch.randn(3, 3) expected = op(x, dim=0) - values, spec = pytree_impl.tree_flatten(expected) + values, spec = pytree.tree_flatten(expected) # Check that values is actually List[Tensor] and not (ReturnType(...),) for value in values: self.assertIsInstance(value, torch.Tensor) - result = pytree_impl.tree_unflatten(values, spec) + result = pytree.tree_unflatten(values, spec) self.assertEqual(type(result), type(expected)) self.assertEqual(result, expected) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_flatten_unflatten_nested(self, pytree_impl): - def run_test(pytree): - values, treespec = pytree_impl.tree_flatten(pytree) + @parametrize_pytree_module + def test_flatten_unflatten_nested(self, pytree): + def run_test(tree): + values, treespec = pytree.tree_flatten(tree) self.assertIsInstance(values, list) self.assertEqual(len(values), treespec.num_leaves) # NB: python basic data structures (dict list tuple) all have # contents equality defined on them, so the following works for them. - unflattened = pytree_impl.tree_unflatten(values, treespec) - self.assertEqual(unflattened, pytree) + unflattened = pytree.tree_unflatten(values, treespec) + self.assertEqual(unflattened, tree) cases = [ [()], @@ -523,17 +512,11 @@ class TestGenericPytree(TestCase): for case in cases: run_test(case) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_flatten_with_is_leaf(self, pytree_impl): - def run_test(pytree, one_level_leaves): - values, treespec = pytree_impl.tree_flatten( - pytree, is_leaf=lambda x: x is not pytree + @parametrize_pytree_module + def test_flatten_with_is_leaf(self, pytree): + def run_test(tree, one_level_leaves): + values, treespec = pytree.tree_flatten( + tree, is_leaf=lambda x: x is not tree ) self.assertIsInstance(values, list) self.assertEqual(len(values), treespec.num_nodes - 1) @@ -543,13 +526,13 @@ class TestGenericPytree(TestCase): self.assertEqual( treespec, - pytree_impl.tree_structure( - pytree_impl.tree_unflatten([0] * treespec.num_leaves, treespec) + pytree.tree_structure( + pytree.tree_unflatten([0] * treespec.num_leaves, treespec) ), ) - unflattened = pytree_impl.tree_unflatten(values, treespec) - self.assertEqual(unflattened, pytree) + unflattened = pytree.tree_unflatten(values, treespec) + self.assertEqual(unflattened, tree) cases = [ ([()], [()]), @@ -568,28 +551,22 @@ class TestGenericPytree(TestCase): for case in cases: run_test(*case) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_tree_map(self, pytree_impl): - def run_test(pytree): + @parametrize_pytree_module + def test_tree_map(self, pytree): + def run_test(tree): def f(x): return x * 3 - sm1 = sum(map(f, pytree_impl.tree_leaves(pytree))) - sm2 = sum(pytree_impl.tree_leaves(pytree_impl.tree_map(f, pytree))) + sm1 = sum(map(f, pytree.tree_leaves(tree))) + sm2 = sum(pytree.tree_leaves(pytree.tree_map(f, tree))) self.assertEqual(sm1, sm2) def invf(x): return x // 3 self.assertEqual( - pytree_impl.tree_map(invf, pytree_impl.tree_map(f, pytree)), - pytree, + pytree.tree_map(invf, pytree.tree_map(f, tree)), + tree, ) cases = [ @@ -602,27 +579,19 @@ class TestGenericPytree(TestCase): for case in cases: run_test(case) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_tree_map_multi_inputs(self, pytree_impl): - def run_test(pytree): + @parametrize_pytree_module + def test_tree_map_multi_inputs(self, pytree): + def run_test(tree): def f(x, y, z): return x, [y, (z, 0)] - pytree_x = pytree - pytree_y = pytree_impl.tree_map(lambda x: (x + 1,), pytree) - pytree_z = pytree_impl.tree_map(lambda x: {"a": x * 2, "b": 2}, pytree) + tree_x = tree + tree_y = pytree.tree_map(lambda x: (x + 1,), tree) + tree_z = pytree.tree_map(lambda x: {"a": x * 2, "b": 2}, tree) self.assertEqual( - pytree_impl.tree_map(f, pytree_x, pytree_y, pytree_z), - pytree_impl.tree_map( - lambda x: f(x, (x + 1,), {"a": x * 2, "b": 2}), pytree - ), + pytree.tree_map(f, tree_x, tree_y, tree_z), + pytree.tree_map(lambda x: f(x, (x + 1,), {"a": x * 2, "b": 2}), tree), ) cases = [ @@ -635,55 +604,29 @@ class TestGenericPytree(TestCase): for case in cases: run_test(case) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_tree_map_only(self, pytree_impl): + @parametrize_pytree_module + def test_tree_map_only(self, pytree): + self.assertEqual(pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]) + + @parametrize_pytree_module + def test_tree_map_only_predicate_fn(self, pytree): self.assertEqual( - pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"] + pytree.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1] ) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_tree_map_only_predicate_fn(self, pytree_impl): - self.assertEqual( - pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1] - ) + @parametrize_pytree_module + def test_tree_all_any(self, pytree): + self.assertTrue(pytree.tree_all(lambda x: x % 2, [1, 3])) + self.assertFalse(pytree.tree_all(lambda x: x % 2, [0, 1])) + self.assertTrue(pytree.tree_any(lambda x: x % 2, [0, 1])) + self.assertFalse(pytree.tree_any(lambda x: x % 2, [0, 2])) + self.assertTrue(pytree.tree_all_only(int, lambda x: x % 2, [1, 3, "a"])) + self.assertFalse(pytree.tree_all_only(int, lambda x: x % 2, [0, 1, "a"])) + self.assertTrue(pytree.tree_any_only(int, lambda x: x % 2, [0, 1, "a"])) + self.assertFalse(pytree.tree_any_only(int, lambda x: x % 2, [0, 2, "a"])) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_tree_all_any(self, pytree_impl): - self.assertTrue(pytree_impl.tree_all(lambda x: x % 2, [1, 3])) - self.assertFalse(pytree_impl.tree_all(lambda x: x % 2, [0, 1])) - self.assertTrue(pytree_impl.tree_any(lambda x: x % 2, [0, 1])) - self.assertFalse(pytree_impl.tree_any(lambda x: x % 2, [0, 2])) - self.assertTrue(pytree_impl.tree_all_only(int, lambda x: x % 2, [1, 3, "a"])) - self.assertFalse(pytree_impl.tree_all_only(int, lambda x: x % 2, [0, 1, "a"])) - self.assertTrue(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 1, "a"])) - self.assertFalse(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 2, "a"])) - - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_broadcast_to_and_flatten(self, pytree_impl): + @parametrize_pytree_module + def test_broadcast_to_and_flatten(self, pytree): cases = [ (1, (), []), # Same (flat) structures @@ -716,29 +659,17 @@ class TestGenericPytree(TestCase): ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]), (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]), ] - for pytree, to_pytree, expected in cases: - _, to_spec = pytree_impl.tree_flatten(to_pytree) - result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec) - self.assertEqual(result, expected, msg=str([pytree, to_spec, expected])) + for tree, to_tree, expected in cases: + _, to_spec = pytree.tree_flatten(to_tree) + result = pytree._broadcast_to_and_flatten(tree, to_spec) + self.assertEqual(result, expected, msg=str([tree, to_spec, expected])) - @parametrize( - "pytree_impl", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) - def test_pytree_serialize_bad_input(self, pytree_impl): + @parametrize_pytree_module + def test_pytree_serialize_bad_input(self, pytree): with self.assertRaises(TypeError): - pytree_impl.treespec_dumps("random_blurb") + pytree.treespec_dumps("random_blurb") - @parametrize( - "pytree", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) + @parametrize_pytree_module def test_is_namedtuple(self, pytree): DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"]) @@ -779,13 +710,7 @@ class TestGenericPytree(TestCase): self.assertFalse(pytree.is_namedtuple_class(tuple)) self.assertFalse(pytree.is_namedtuple_class(list)) - @parametrize( - "pytree", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) + @parametrize_pytree_module def test_is_structseq(self, pytree): class FakeStructSeq(tuple): n_fields = 2 @@ -859,13 +784,7 @@ class TestGenericPytree(TestCase): self.assertFalse(pytree.is_namedtuple(cls)) self.assertFalse(pytree.is_namedtuple_class(cls)) - @parametrize( - "pytree", - [ - subtest(py_pytree, name="py"), - subtest(cxx_pytree, name="cxx"), - ], - ) + @parametrize_pytree_module def test_enum_treespec_roundtrip(self, pytree): data = {TestEnum.A: 5} spec = pytree.tree_structure(data) @@ -885,14 +804,14 @@ class TestPythonPytree(TestCase): with self.assertWarnsRegex( FutureWarning, "torch.utils._pytree._register_pytree_node" ): - py_pytree._register_pytree_node( + python_pytree._register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), ) with self.assertWarnsRegex(UserWarning, "already registered"): - py_pytree._register_pytree_node( + python_pytree._register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -929,28 +848,30 @@ if "optree" in sys.modules: def test_treespec_equality(self): self.assertEqual( - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), + python_pytree.LeafSpec(), + python_pytree.LeafSpec(), ) self.assertEqual( - py_pytree.TreeSpec(list, None, []), - py_pytree.TreeSpec(list, None, []), + python_pytree.TreeSpec(list, None, []), + python_pytree.TreeSpec(list, None, []), ) self.assertEqual( - py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), - py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), + python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]), + python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]), ) self.assertFalse( - py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []), + python_pytree.TreeSpec(tuple, None, []) + == python_pytree.TreeSpec(list, None, []), ) self.assertTrue( - py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []), + python_pytree.TreeSpec(tuple, None, []) + != python_pytree.TreeSpec(list, None, []), ) def test_treespec_repr(self): # Check that it looks sane - pytree = (0, [0, 0, [0]]) - _, spec = py_pytree.tree_flatten(pytree) + tree = (0, [0, 0, [0]]) + spec = python_pytree.tree_structure(tree) self.assertEqual( repr(spec), ( @@ -964,113 +885,86 @@ if "optree" in sys.modules: @parametrize( "spec", [ - # py_pytree.tree_structure([]) - py_pytree.TreeSpec(list, None, []), - # py_pytree.tree_structure(()) - py_pytree.TreeSpec(tuple, None, []), - # py_pytree.tree_structure({}) - py_pytree.TreeSpec(dict, [], []), - # py_pytree.tree_structure([0]) - py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), - # py_pytree.tree_structure([0, 1]) - py_pytree.TreeSpec( + # python_pytree.tree_structure([]) + python_pytree.TreeSpec(list, None, []), + # python_pytree.tree_structure(()) + python_pytree.TreeSpec(tuple, None, []), + # python_pytree.tree_structure({}) + python_pytree.TreeSpec(dict, [], []), + # python_pytree.tree_structure([0]) + python_pytree.TreeSpec(list, None, [python_leafspec]), + # python_pytree.tree_structure([0, 1]) + python_pytree.TreeSpec( list, None, - [ - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - ], + [python_leafspec, python_leafspec], ), - # py_pytree.tree_structure((0, 1, 2)) - py_pytree.TreeSpec( + # python_pytree.tree_structure((0, 1, 2)) + python_pytree.TreeSpec( tuple, None, - [ - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - ], + [python_leafspec, python_leafspec, python_leafspec], ), - # py_pytree.tree_structure({"a": 0, "b": 1, "c": 2}) - py_pytree.TreeSpec( + # python_pytree.tree_structure({"a": 0, "b": 1, "c": 2}) + python_pytree.TreeSpec( dict, ["a", "b", "c"], - [ - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - ], + [python_leafspec, python_leafspec, python_leafspec], ), - # py_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) - py_pytree.TreeSpec( + # python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) + python_pytree.TreeSpec( OrderedDict, ["a", "b", "c"], [ - py_pytree.TreeSpec( + python_pytree.TreeSpec( tuple, None, - [ - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - ], + [python_leafspec, python_leafspec], ), - py_pytree.LeafSpec(), - py_pytree.TreeSpec( + python_leafspec, + python_pytree.TreeSpec( dict, ["a", "b", "c"], - [ - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - ], + [python_leafspec, python_leafspec, python_leafspec], ), ], ), - # py_pytree.tree_structure([(0, 1, [2, 3])]) - py_pytree.TreeSpec( + # python_pytree.tree_structure([(0, 1, [2, 3])]) + python_pytree.TreeSpec( list, None, [ - py_pytree.TreeSpec( + python_pytree.TreeSpec( tuple, None, [ - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - py_pytree.TreeSpec( + python_leafspec, + python_leafspec, + python_pytree.TreeSpec( list, None, - [ - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - ], + [python_leafspec, python_leafspec], ), ], ), ], ), - # py_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})) - py_pytree.TreeSpec( + # python_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})) + python_pytree.TreeSpec( defaultdict, [list, ["a", "b", "c"]], [ - py_pytree.TreeSpec( + python_pytree.TreeSpec( list, None, - [ - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - ], + [python_leafspec, python_leafspec], ), - py_pytree.TreeSpec( + python_pytree.TreeSpec( list, None, - [ - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - ], + [python_leafspec, python_leafspec], ), - py_pytree.TreeSpec(dict, [], []), + python_pytree.TreeSpec(dict, [], []), ], ), ], @@ -1079,86 +973,92 @@ if "optree" in sys.modules: # Ensure that the spec is valid self.assertEqual( spec, - py_pytree.tree_structure( - py_pytree.tree_unflatten([0] * spec.num_leaves, spec) + python_pytree.tree_structure( + python_pytree.tree_unflatten([0] * spec.num_leaves, spec) ), ) - serialized_spec = py_pytree.treespec_dumps(spec) + serialized_spec = python_pytree.treespec_dumps(spec) self.assertIsInstance(serialized_spec, str) - self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec)) + self.assertEqual(spec, python_pytree.treespec_loads(serialized_spec)) def test_pytree_serialize_defaultdict_enum(self): - spec = py_pytree.TreeSpec( + spec = python_pytree.TreeSpec( defaultdict, [list, [TestEnum.A]], [ - py_pytree.TreeSpec( + python_pytree.TreeSpec( list, None, [ - py_pytree.LeafSpec(), + python_leafspec, ], ), ], ) - serialized_spec = py_pytree.treespec_dumps(spec) + serialized_spec = python_pytree.treespec_dumps(spec) self.assertIsInstance(serialized_spec, str) def test_pytree_serialize_enum(self): - spec = py_pytree.TreeSpec(dict, TestEnum.A, [py_pytree.LeafSpec()]) + spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec]) - serialized_spec = py_pytree.treespec_dumps(spec) + serialized_spec = python_pytree.treespec_dumps(spec) self.assertIsInstance(serialized_spec, str) def test_pytree_serialize_namedtuple(self): Point1 = namedtuple("Point1", ["x", "y"]) - py_pytree._register_namedtuple( + python_pytree._register_namedtuple( Point1, serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1", ) - spec = py_pytree.tree_structure(Point1(1, 2)) + spec = python_pytree.tree_structure(Point1(1, 2)) self.assertIs(spec.type, namedtuple) - roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) + roundtrip_spec = python_pytree.treespec_loads( + python_pytree.treespec_dumps(spec) + ) self.assertEqual(spec, roundtrip_spec) class Point2(NamedTuple): x: int y: int - py_pytree._register_namedtuple( + python_pytree._register_namedtuple( Point2, serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2", ) - spec = py_pytree.tree_structure(Point2(1, 2)) + spec = python_pytree.tree_structure(Point2(1, 2)) self.assertIs(spec.type, namedtuple) - roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) + roundtrip_spec = python_pytree.treespec_loads( + python_pytree.treespec_dumps(spec) + ) self.assertEqual(spec, roundtrip_spec) class Point3(Point2): pass - py_pytree._register_namedtuple( + python_pytree._register_namedtuple( Point3, serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point3", ) - spec = py_pytree.tree_structure(Point3(1, 2)) + spec = python_pytree.tree_structure(Point3(1, 2)) self.assertIs(spec.type, namedtuple) - roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) + roundtrip_spec = python_pytree.treespec_loads( + python_pytree.treespec_dumps(spec) + ) self.assertEqual(spec, roundtrip_spec) def test_pytree_serialize_namedtuple_bad(self): DummyType = namedtuple("DummyType", ["x", "y"]) - spec = py_pytree.tree_structure(DummyType(1, 2)) + spec = python_pytree.tree_structure(DummyType(1, 2)) with self.assertRaisesRegex( NotImplementedError, "Please register using `_register_namedtuple`" ): - py_pytree.treespec_dumps(spec) + python_pytree.treespec_dumps(spec) def test_pytree_custom_type_serialize_bad(self): class DummyType: @@ -1166,17 +1066,17 @@ if "optree" in sys.modules: self.x = x self.y = y - py_pytree.register_pytree_node( + python_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), ) - spec = py_pytree.tree_structure(DummyType(1, 2)) + spec = python_pytree.tree_structure(DummyType(1, 2)) with self.assertRaisesRegex( NotImplementedError, "No registered serialization name" ): - py_pytree.treespec_dumps(spec) + python_pytree.treespec_dumps(spec) def test_pytree_custom_type_serialize(self): class DummyType: @@ -1184,7 +1084,7 @@ if "optree" in sys.modules: self.x = x self.y = y - py_pytree.register_pytree_node( + python_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -1192,10 +1092,10 @@ if "optree" in sys.modules: to_dumpable_context=lambda context: "moo", from_dumpable_context=lambda dumpable_context: None, ) - spec = py_pytree.tree_structure(DummyType(1, 2)) - serialized_spec = py_pytree.treespec_dumps(spec, 1) + spec = python_pytree.tree_structure(DummyType(1, 2)) + serialized_spec = python_pytree.treespec_dumps(spec, 1) self.assertIn("moo", serialized_spec) - roundtrip_spec = py_pytree.treespec_loads(serialized_spec) + roundtrip_spec = python_pytree.treespec_loads(serialized_spec) self.assertEqual(roundtrip_spec, spec) def test_pytree_serialize_register_bad(self): @@ -1207,7 +1107,7 @@ if "optree" in sys.modules: with self.assertRaisesRegex( ValueError, "Both to_dumpable_context and from_dumpable_context" ): - py_pytree.register_pytree_node( + python_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -1221,7 +1121,7 @@ if "optree" in sys.modules: self.x = x self.y = y - py_pytree.register_pytree_node( + python_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -1230,65 +1130,59 @@ if "optree" in sys.modules: from_dumpable_context=lambda dumpable_context: None, ) - spec = py_pytree.tree_structure(DummyType(1, 2)) + spec = python_pytree.tree_structure(DummyType(1, 2)) with self.assertRaisesRegex( TypeError, "Object of type type is not JSON serializable" ): - py_pytree.treespec_dumps(spec) + python_pytree.treespec_dumps(spec) def test_pytree_serialize_bad_protocol(self): import json Point = namedtuple("Point", ["x", "y"]) - spec = py_pytree.tree_structure(Point(1, 2)) - py_pytree._register_namedtuple( + spec = python_pytree.tree_structure(Point(1, 2)) + python_pytree._register_namedtuple( Point, serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point", ) with self.assertRaisesRegex(ValueError, "Unknown protocol"): - py_pytree.treespec_dumps(spec, -1) + python_pytree.treespec_dumps(spec, -1) - serialized_spec = py_pytree.treespec_dumps(spec) + serialized_spec = python_pytree.treespec_dumps(spec) _, data = json.loads(serialized_spec) bad_protocol_serialized_spec = json.dumps((-1, data)) with self.assertRaisesRegex(ValueError, "Unknown protocol"): - py_pytree.treespec_loads(bad_protocol_serialized_spec) + python_pytree.treespec_loads(bad_protocol_serialized_spec) def test_saved_serialized(self): - # py_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 3, 5: 4, 6: 5})])) - complicated_spec = py_pytree.TreeSpec( + # python_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 3, 5: 4, 6: 5})])) + complicated_spec = python_pytree.TreeSpec( OrderedDict, [1, 2, 3], [ - py_pytree.TreeSpec( - tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ), - py_pytree.LeafSpec(), - py_pytree.TreeSpec( + python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]), + python_leafspec, + python_pytree.TreeSpec( dict, [4, 5, 6], - [ - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - py_pytree.LeafSpec(), - ], + [python_leafspec, python_leafspec, python_leafspec], ), ], ) # Ensure that the spec is valid self.assertEqual( complicated_spec, - py_pytree.tree_structure( - py_pytree.tree_unflatten( + python_pytree.tree_structure( + python_pytree.tree_unflatten( [0] * complicated_spec.num_leaves, complicated_spec ) ), ) - serialized_spec = py_pytree.treespec_dumps(complicated_spec) + serialized_spec = python_pytree.treespec_dumps(complicated_spec) saved_spec = ( '[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", ' '"children_spec": [{"type": "builtins.tuple", "context": "null", ' @@ -1301,11 +1195,11 @@ if "optree" in sys.modules: '[]}, {"type": null, "context": null, "children_spec": []}]}]}]' ) self.assertEqual(serialized_spec, saved_spec) - self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec)) + self.assertEqual(complicated_spec, python_pytree.treespec_loads(saved_spec)) def test_tree_map_with_path(self): tree = [{i: i for i in range(10)}] - all_zeros = py_pytree.tree_map_with_path( + all_zeros = python_pytree.tree_map_with_path( lambda kp, val: val - kp[1].key + kp[0].idx, tree ) self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)]) @@ -1318,34 +1212,34 @@ if "optree" in sys.modules: c: Optional[str] = None d: str = field(init=False, default="") - py_pytree.register_dataclass(Data) + python_pytree.register_dataclass(Data) old_data = Data(torch.tensor(3), "b", "c") old_data.d = "d" - new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data)) + new_data = python_pytree.tree_map(lambda x: x, old_data) self.assertEqual(new_data.a, torch.tensor(3)) self.assertEqual(new_data.b, "b") self.assertEqual(new_data.c, "c") self.assertEqual(new_data.d, "") - py_pytree._deregister_pytree_node(Data) + python_pytree._deregister_pytree_node(Data) with self.assertRaisesRegex(ValueError, "Missing fields"): - py_pytree.register_dataclass(Data, field_names=["a", "b"]) + python_pytree.register_dataclass(Data, field_names=["a", "b"]) with self.assertRaisesRegex(ValueError, "Unexpected fields"): - py_pytree.register_dataclass(Data, field_names=["a", "b", "e"]) + python_pytree.register_dataclass(Data, field_names=["a", "b", "e"]) with self.assertRaisesRegex(ValueError, "Unexpected fields"): - py_pytree.register_dataclass(Data, field_names=["a", "b", "c", "d"]) + python_pytree.register_dataclass(Data, field_names=["a", "b", "c", "d"]) - py_pytree.register_dataclass( + python_pytree.register_dataclass( Data, field_names=["a"], drop_field_names=["b", "c"] ) old_data = Data(torch.tensor(3), "b", "c") - new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data)) + new_data = python_pytree.tree_map(lambda x: x, old_data) self.assertEqual(new_data.a, torch.tensor(3)) self.assertEqual(new_data.b, "moo") self.assertEqual(new_data.c, None) - py_pytree._deregister_pytree_node(Data) + python_pytree._deregister_pytree_node(Data) def test_register_dataclass_class(self): class CustomClass: @@ -1354,11 +1248,11 @@ if "optree" in sys.modules: self.y = y with self.assertRaisesRegex(ValueError, "field_names must be specified"): - py_pytree.register_dataclass(CustomClass) + python_pytree.register_dataclass(CustomClass) - py_pytree.register_dataclass(CustomClass, field_names=["x", "y"]) + python_pytree.register_dataclass(CustomClass, field_names=["x", "y"]) c = CustomClass(torch.tensor(0), torch.tensor(1)) - mapped = py_pytree.tree_map(lambda x: x + 1, c) + mapped = python_pytree.tree_map(lambda x: x + 1, c) self.assertEqual(mapped.x, torch.tensor(1)) self.assertEqual(mapped.y, torch.tensor(2)) @@ -1369,10 +1263,10 @@ if "optree" in sys.modules: class Config: norm: str - py_pytree.register_constant(Config) + python_pytree.register_constant(Config) config = Config("l1") - elements, spec = py_pytree.tree_flatten(config) + elements, spec = python_pytree.tree_flatten(config) self.assertEqual(elements, []) self.assertEqual(spec.context.value, config) @@ -1382,7 +1276,7 @@ if "optree" in sys.modules: self.norm = norm try: - py_pytree.register_constant(Config) + python_pytree.register_constant(Config) self.assertFalse(True) # must raise error before this except TypeError as e: msg = "register_constant(cls) expects `cls` to have a non-default `__eq__` implementation." @@ -1397,7 +1291,7 @@ if "optree" in sys.modules: return self.norm == other.norm try: - py_pytree.register_constant(Config) + python_pytree.register_constant(Config) self.assertFalse(True) # must raise error before this except TypeError as e: msg = "register_constant(cls) expects `cls` to have a non-default `__hash__` implementation." @@ -1413,23 +1307,23 @@ if "optree" in sys.modules: tree1 = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5] tree2 = [ACustomPytree(x=2, y={"cin": [2, 2, 2], "bar": 2}, z="leaf"), 2] - py_pytree.register_pytree_node( + python_pytree.register_pytree_node( ACustomPytree, flatten_fn=lambda f: ([f.x, f.y], f.z), unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z), ) - from_two_trees = py_pytree.tree_map_with_path( + from_two_trees = python_pytree.tree_map_with_path( lambda kp, a, b: a + b, tree1, tree2 ) - from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1) + from_one_tree = python_pytree.tree_map(lambda a: a + 2, tree1) self.assertEqual(from_two_trees, from_one_tree) def test_tree_flatten_with_path_is_leaf(self): leaf_dict = {"foo": [(3)]} - pytree = (["hello", [1, 2], leaf_dict],) - key_leaves, _ = py_pytree.tree_flatten_with_path( - pytree, is_leaf=lambda x: isinstance(x, dict) + tree = (["hello", [1, 2], leaf_dict],) + key_leaves, _ = python_pytree.tree_flatten_with_path( + tree, is_leaf=lambda x: isinstance(x, dict) ) self.assertTrue(key_leaves[-1][1] is leaf_dict) @@ -1445,7 +1339,7 @@ if "optree" in sys.modules: y: Any z: Any - py_pytree.register_pytree_node( + python_pytree.register_pytree_node( ACustomPytree, flatten_fn=lambda f: ([f.x, f.y], f.z), unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), @@ -1458,10 +1352,12 @@ if "optree" in sys.modules: [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")], [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5], ] - for pytree in SOME_PYTREES: - key_leaves, spec = py_pytree.tree_flatten_with_path(pytree) - actual = py_pytree.tree_unflatten([leaf for _, leaf in key_leaves], spec) - self.assertEqual(actual, pytree) + for tree in SOME_PYTREES: + key_leaves, spec = python_pytree.tree_flatten_with_path(tree) + actual = python_pytree.tree_unflatten( + [leaf for _, leaf in key_leaves], spec + ) + self.assertEqual(actual, tree) def test_tree_leaves_with_path(self): class ANamedTuple(NamedTuple): @@ -1475,7 +1371,7 @@ if "optree" in sys.modules: y: Any z: Any - py_pytree.register_pytree_node( + python_pytree.register_pytree_node( ACustomPytree, flatten_fn=lambda f: ([f.x, f.y], f.z), unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), @@ -1488,9 +1384,9 @@ if "optree" in sys.modules: [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")], [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5], ] - for pytree in SOME_PYTREES: - flat_out, _ = py_pytree.tree_flatten_with_path(pytree) - leaves_out = py_pytree.tree_leaves_with_path(pytree) + for tree in SOME_PYTREES: + flat_out, _ = python_pytree.tree_flatten_with_path(tree) + leaves_out = python_pytree.tree_leaves_with_path(tree) self.assertEqual(flat_out, leaves_out) def test_key_str(self): @@ -1499,8 +1395,8 @@ if "optree" in sys.modules: y: int tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],) - flat, _ = py_pytree.tree_flatten_with_path(tree) - paths = [f"{py_pytree.keystr(kp)}: {val}" for kp, val in flat] + flat, _ = python_pytree.tree_flatten_with_path(tree) + paths = [f"{python_pytree.keystr(kp)}: {val}" for kp, val in flat] self.assertEqual( paths, [ @@ -1515,7 +1411,7 @@ if "optree" in sys.modules: def test_flatten_flatten_with_key_consistency(self): """Check that flatten and flatten_with_key produces consistent leaves/context.""" - reg = py_pytree.SUPPORTED_NODES + reg = python_pytree.SUPPORTED_NODES EXAMPLE_TREE = { list: [1, 2, 3], @@ -1534,8 +1430,8 @@ if "optree" in sys.modules: example = EXAMPLE_TREE.get(typ) if example is None: continue - flat_with_path, spec1 = py_pytree.tree_flatten_with_path(example) - flat, spec2 = py_pytree.tree_flatten(example) + flat_with_path, spec1 = python_pytree.tree_flatten_with_path(example) + flat, spec2 = python_pytree.tree_flatten(example) self.assertEqual(flat, [x[1] for x in flat_with_path]) self.assertEqual(spec1, spec2) @@ -1546,9 +1442,9 @@ if "optree" in sys.modules: y: int tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],) - flat, _ = py_pytree.tree_flatten_with_path(tree) + flat, _ = python_pytree.tree_flatten_with_path(tree) for kp, val in flat: - self.assertEqual(py_pytree.key_get(tree, kp), val) + self.assertEqual(python_pytree.key_get(tree, kp), val) class TestCxxPytree(TestCase): @@ -1561,8 +1457,8 @@ class TestCxxPytree(TestCase): def test_treespec_repr(self): # Check that it looks sane - pytree = (0, [0, 0, [0]]) - _, spec = cxx_pytree.tree_flatten(pytree) + tree = (0, [0, 0, [0]]) + spec = cxx_pytree.tree_structure(tree) self.assertEqual( repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')" ) @@ -1599,7 +1495,7 @@ class TestCxxPytree(TestCase): self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec)) def test_pytree_serialize_namedtuple(self): - py_pytree._register_namedtuple( + python_pytree._register_namedtuple( GlobalPoint, serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.GlobalPoint", ) @@ -1609,7 +1505,7 @@ class TestCxxPytree(TestCase): self.assertEqual(roundtrip_spec.type._fields, spec.type._fields) LocalPoint = namedtuple("LocalPoint", ["x", "y"]) - py_pytree._register_namedtuple( + python_pytree._register_namedtuple( LocalPoint, serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.LocalPoint", )