Files
pytorch/test/export/test_cpp_serdes.py
Zhengxu Chen 3ef2dfc1ba [export] Implement cpp deserializer. (#136398)
Differential Revision: D63206258

This diff introduces a mechanism to generate a JSON-compatible deserializer in C++ using nlohmann/json (already used by AOTI).

Why do we need this? Because there are many cases where people don't want to use Python to load the graph (e.g. in a C++ runtime); they can instead use this header to deserialize the JSON graph.

Every time we call update_schema.py to update the schema, the header will be auto-generated and included in the source files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136398
Approved by: https://github.com/angelayi
2024-11-14 16:34:59 +00:00

61 lines
1.4 KiB
Python

# Owner(s): ["oncall: export"]
import torch
from torch._export.serde.serialize import deserialize, serialize
try:
from . import test_export, testing
except ImportError:
import test_export # @manual=fbcode//caffe2/test:test_export-library
import testing # @manual=fbcode//caffe2/test:test_export-library
from torch.export import export
# Generated test classes, keyed by class name; populated by make_dynamic_cls.
test_classes = {}
def mocked_cpp_serdes_export(*args, **kwargs):
ep = export(*args, **kwargs)
try:
payload = serialize(ep)
except Exception:
return ep
cpp_ep = torch._C._export.deserialize_exported_program(payload.exported_program)
loaded_json = torch._C._export.serialize_exported_program(cpp_ep)
payload.exported_program = loaded_json.encode()
loaded_ep = deserialize(payload)
return loaded_ep
def make_dynamic_cls(cls):
    """Create and register a "CppSerdes" variant of the test class ``cls``.

    The variant is built by ``testing.make_test_cls_with_mocked_export``,
    passing ``mocked_cpp_serdes_export`` as the mocked export function. It is
    then recorded in ``test_classes`` and in this module's globals under its
    own class name, and its ``__module__`` is pointed at this module.
    """
    generated = testing.make_test_cls_with_mocked_export(
        cls,
        "CppSerdes",
        "_cpp_serdes",
        mocked_cpp_serdes_export,
        xfail_prop="_expected_failure_cpp_serdes",
    )
    generated_name = generated.__name__
    generated.__module__ = __name__
    test_classes[generated_name] = generated
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[generated_name] = generated
# Base test classes for which a generated variant is registered.
tests = [test_export.TestExport]
for base_suite in tests:
    make_dynamic_cls(base_suite)
# Unbind the loop variable so the base class is not left as a module global.
del base_suite
if __name__ == "__main__":
    # Imported here so the runner is only loaded when executed as a script.
    from torch._dynamo import test_case

    test_case.run_tests()