mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[TreeSpec] Support enum in defaultdict (#144235)
Summary: Followup from D66269157, add support for enum in defaultdict. Test Plan: Added unit test Differential Revision: D67832100 Pull Request resolved: https://github.com/pytorch/pytorch/pull/144235 Approved by: https://github.com/henrylhtsang, https://github.com/houseroad
This commit is contained in:
committed by
PyTorch MergeBot
parent
c68c38c673
commit
f013cfee38
@ -51,6 +51,10 @@ cxx_pytree.register_pytree_node(
|
||||
)
|
||||
|
||||
|
||||
class TestEnum(enum.Enum):
|
||||
A = auto()
|
||||
|
||||
|
||||
class TestGenericPytree(TestCase):
|
||||
def test_aligned_public_apis(self):
|
||||
public_apis = py_pytree.__all__
|
||||
@ -957,10 +961,24 @@ TreeSpec(tuple, None, [*,
|
||||
self.assertIsInstance(serialized_spec, str)
|
||||
self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec))
|
||||
|
||||
def test_pytree_serialize_enum(self):
|
||||
class TestEnum(enum.Enum):
|
||||
A = auto()
|
||||
def test_pytree_serialize_defaultdict_enum(self):
|
||||
spec = py_pytree.TreeSpec(
|
||||
defaultdict,
|
||||
[list, [TestEnum.A]],
|
||||
[
|
||||
py_pytree.TreeSpec(
|
||||
list,
|
||||
None,
|
||||
[
|
||||
py_pytree.LeafSpec(),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
serialized_spec = py_pytree.treespec_dumps(spec)
|
||||
self.assertIsInstance(serialized_spec, str)
|
||||
|
||||
def test_pytree_serialize_enum(self):
|
||||
spec = py_pytree.TreeSpec(dict, TestEnum.A, [py_pytree.LeafSpec()])
|
||||
|
||||
serialized_spec = py_pytree.treespec_dumps(spec)
|
||||
|
Reference in New Issue
Block a user