[TreeSpec] Support enum in defaultdict (#144235)

Summary: Followup from D66269157, add support for enum in defaultdict.

Test Plan: Added unit test

Differential Revision: D67832100

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144235
Approved by: https://github.com/henrylhtsang, https://github.com/houseroad
This commit is contained in:
Henry Hu
2025-01-07 00:10:46 +00:00
committed by PyTorch MergeBot
parent c68c38c673
commit f013cfee38
2 changed files with 22 additions and 4 deletions

View File

@ -51,6 +51,10 @@ cxx_pytree.register_pytree_node(
)
class TestEnum(enum.Enum):
A = auto()
class TestGenericPytree(TestCase):
def test_aligned_public_apis(self):
public_apis = py_pytree.__all__
@ -957,10 +961,24 @@ TreeSpec(tuple, None, [*,
self.assertIsInstance(serialized_spec, str)
self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec))
def test_pytree_serialize_enum(self):
class TestEnum(enum.Enum):
A = auto()
def test_pytree_serialize_defaultdict_enum(self):
spec = py_pytree.TreeSpec(
defaultdict,
[list, [TestEnum.A]],
[
py_pytree.TreeSpec(
list,
None,
[
py_pytree.LeafSpec(),
],
),
],
)
serialized_spec = py_pytree.treespec_dumps(spec)
self.assertIsInstance(serialized_spec, str)
def test_pytree_serialize_enum(self):
spec = py_pytree.TreeSpec(dict, TestEnum.A, [py_pytree.LeafSpec()])
serialized_spec = py_pytree.treespec_dumps(spec)