Preserve Enum types during torch.export serialization and deserialization (#154821)

Fixes #154674

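Addresses an issue where `torch.export` does not correctly preserve Python `Enum` types during the save/load round-trip. Previously, Enum inputs were serialized by value only, causing their type to be lost after deserialization. With this change, an Enum member is serialized as a tagged dictionary holding the enum class's module and qualified name together with the member name, and it is rebuilt into the original member by a JSON object hook on load.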

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154821
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007, https://github.com/yushangdi, https://github.com/angelayi
This commit is contained in:
Narek Malkhasyan
2025-06-08 17:30:28 +00:00
committed by PyTorch MergeBot
parent 27df0c56b7
commit 30293b8b5e
2 changed files with 35 additions and 4 deletions

View File

@ -859,6 +859,21 @@ class TestGenericPytree(TestCase):
self.assertFalse(pytree.is_namedtuple(cls))
self.assertFalse(pytree.is_namedtuple_class(cls))
@parametrize(
    "pytree",
    [
        subtest(py_pytree, name="py"),
        subtest(cxx_pytree, name="cxx"),
    ],
)
def test_enum_treespec_roundtrip(self, pytree):
    """A treespec built from an Enum-keyed dict must survive dumps/loads unchanged."""
    # Build the spec from a dict whose only key is an Enum member.
    original_spec = pytree.tree_structure({TestEnum.A: 5})
    # Serialize and immediately deserialize; the result must compare equal.
    restored_spec = pytree.treespec_loads(pytree.treespec_dumps(original_spec))
    self.assertEqual(original_spec, restored_spec)
class TestPythonPytree(TestCase):
def test_deprecated_register_pytree_node(self):

View File

@ -113,10 +113,14 @@ class KeyEntry(Protocol):
class EnumEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``Enum`` members as tagged dictionaries.

    Encoding an enum by its bare value loses the enum type on a round-trip.
    Instead, each member is written as a dictionary that records which class
    it belongs to and which member it is, so a decoder can rebuild it.
    """

    def default(self, obj: object) -> Union[str, dict[str, Any]]:
        """Return a JSON-serializable stand-in for ``obj``.

        Args:
            obj: The object the standard encoder could not serialize.

        Returns:
            For an ``Enum`` member, a dictionary with the keys ``"__enum__"``
            (always ``True``), ``"fqn"`` (``"<module>:<qualified class name>"``)
            and ``"name"`` (the member name). Anything else is delegated to
            ``json.JSONEncoder.default``.

        Raises:
            TypeError: Raised by the base class for unsupported objects.
        """
        if isinstance(obj, Enum):
            return {
                "__enum__": True,
                "fqn": f"{obj.__class__.__module__}:{obj.__class__.__qualname__}",
                "name": obj.name,
            }
        # The base implementation always raises TypeError; the cast only
        # satisfies the declared return type.
        return cast(str, super().default(obj))
Context = Any
@ -1836,6 +1840,18 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)
def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
    """Rebuild an ``Enum`` member from its tagged-dictionary JSON form.

    Intended for use as the ``object_hook`` of ``json.loads``. A dictionary
    containing the ``"__enum__"`` key is treated as an encoded member:
    ``"fqn"`` holds ``"<module>:<qualified class name>"`` and ``"name"`` holds
    the member name. Every other dictionary is returned unchanged.

    Args:
        obj: A dictionary produced by the JSON decoder.

    Returns:
        The referenced ``Enum`` member, or ``obj`` itself when it is not tagged.

    Raises:
        TypeError: If ``"fqn"`` resolves to something that is not an ``Enum``
            subclass.
        KeyError: If ``"fqn"`` or ``"name"`` is missing from a tagged
            dictionary, or the enum has no member with that name.
        ImportError: If the module named in ``"fqn"`` cannot be imported.
        AttributeError: If the qualified class name cannot be resolved inside
            that module.
    """
    if "__enum__" not in obj:
        return obj

    modname, _, classname = obj["fqn"].partition(":")
    # NOTE(security): the module name comes from the serialized payload, so
    # decoding untrusted data can import (and thereby execute) arbitrary
    # importable modules. Only load treespecs from trusted sources.
    target: Any = importlib.import_module(modname)
    # Walk the dotted qualified name so nested classes resolve as well.
    for attr in classname.split("."):
        target = getattr(target, attr)

    # A typing cast performs no runtime check. Without this guard a non-Enum
    # target would be subscripted anyway (a generic class would quietly return
    # an alias object instead of failing).
    if not (isinstance(target, type) and issubclass(target, Enum)):
        raise TypeError(
            f"Expected {obj['fqn']!r} to resolve to an Enum subclass, "
            f"but got {target!r}"
        )
    return target[obj["name"]]
def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
if (
json_schema["type"] is None
@ -1854,7 +1870,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
if serialize_node_def.from_dumpable_context is None:
try:
context = json.loads(json_schema["context"])
context = json.loads(json_schema["context"], object_hook=enum_object_hook)
except TypeError as ex:
raise TypeError(
"Unable to deserialize context. "