mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Preserve Enum types during torch.export serialization and deserialization (#154821)
Fixes #154674 Addresses an issue where `torch.export` does not correctly preserve Python `Enum` types during the save/load round-trip. Previously, Enum inputs were serialized by value only, causing their type to be lost after deserialization. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154821 Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007, https://github.com/yushangdi, https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
27df0c56b7
commit
30293b8b5e
@ -859,6 +859,21 @@ class TestGenericPytree(TestCase):
|
||||
self.assertFalse(pytree.is_namedtuple(cls))
|
||||
self.assertFalse(pytree.is_namedtuple_class(cls))
|
||||
|
||||
@parametrize(
|
||||
"pytree",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_enum_treespec_roundtrip(self, pytree):
|
||||
data = {TestEnum.A: 5}
|
||||
spec = pytree.tree_structure(data)
|
||||
|
||||
serialized = pytree.treespec_dumps(spec)
|
||||
deserialized_spec = pytree.treespec_loads(serialized)
|
||||
self.assertEqual(spec, deserialized_spec)
|
||||
|
||||
|
||||
class TestPythonPytree(TestCase):
|
||||
def test_deprecated_register_pytree_node(self):
|
||||
|
@ -113,10 +113,14 @@ class KeyEntry(Protocol):
|
||||
|
||||
|
||||
class EnumEncoder(json.JSONEncoder):
|
||||
def default(self, obj: object) -> str:
|
||||
def default(self, obj: object) -> Union[str, dict[str, Any]]:
|
||||
if isinstance(obj, Enum):
|
||||
return obj.value # type: ignore[no-any-return]
|
||||
return super().default(obj) # type: ignore[no-any-return]
|
||||
return {
|
||||
"__enum__": True,
|
||||
"fqn": f"{obj.__class__.__module__}:{obj.__class__.__qualname__}",
|
||||
"name": obj.name,
|
||||
}
|
||||
return cast(str, super().default(obj))
|
||||
|
||||
|
||||
Context = Any
|
||||
@ -1836,6 +1840,18 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
|
||||
return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)
|
||||
|
||||
|
||||
def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
|
||||
if "__enum__" in obj:
|
||||
modname, _, classname = obj["fqn"].partition(":")
|
||||
mod = importlib.import_module(modname)
|
||||
enum_cls = mod
|
||||
for attr in classname.split("."):
|
||||
enum_cls = getattr(enum_cls, attr)
|
||||
enum_cls = cast(type[Enum], enum_cls)
|
||||
return enum_cls[obj["name"]]
|
||||
return obj
|
||||
|
||||
|
||||
def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
|
||||
if (
|
||||
json_schema["type"] is None
|
||||
@ -1854,7 +1870,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
|
||||
|
||||
if serialize_node_def.from_dumpable_context is None:
|
||||
try:
|
||||
context = json.loads(json_schema["context"])
|
||||
context = json.loads(json_schema["context"], object_hook=enum_object_hook)
|
||||
except TypeError as ex:
|
||||
raise TypeError(
|
||||
"Unable to deserialize context. "
|
||||
|
Reference in New Issue
Block a user