mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Differential Revision: D63206258 This diff introduces a mechanism to generate a JSON-compatible deserializer in C++ using nlohmann json (already used by AOTI). Why do we need this? Because there will be many cases where people don't want to use Python to load the graph (e.g. a C++ runtime); instead, they can use this header to deserialize the JSON graph. Every time we call update_schema.py to update the schema, the header will be auto-generated and included in the source files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136398 Approved by: https://github.com/angelayi
61 lines
1.4 KiB
Python
61 lines
1.4 KiB
Python
# Owner(s): ["oncall: export"]
|
|
|
|
|
|
import torch
|
|
from torch._export.serde.serialize import deserialize, serialize
|
|
|
|
|
|
try:
|
|
from . import test_export, testing
|
|
except ImportError:
|
|
import test_export # @manual=fbcode//caffe2/test:test_export-library
|
|
import testing # @manual=fbcode//caffe2/test:test_export-library
|
|
|
|
from torch.export import export
|
|
|
|
|
|
# Registry of the generated test classes, keyed by class name
# (entries are added by make_dynamic_cls).
test_classes = dict()
|
|
|
|
|
|
def mocked_cpp_serdes_export(*args, **kwargs):
    """Export a program and round-trip it through the C++ JSON (de)serializer.

    This function is handed to ``testing.make_test_cls_with_mocked_export`` as
    the mocked export, presumably used in place of ``export`` by the generated
    ``CppSerdes`` test classes. All arguments are forwarded to ``export``
    unchanged.

    Pipeline:
      1. ``export(*args, **kwargs)`` produces the ExportedProgram.
      2. ``serialize`` turns it into an artifact in Python.
      3. The artifact's ``exported_program`` payload is deserialized and then
         re-serialized by the ``torch._C._export`` bindings.
      4. The C++-produced string is UTF-8 encoded back into the artifact,
         and the artifact is deserialized in Python.

    Returns:
        The program rebuilt from the C++-emitted payload. If step 2 raises,
        the freshly exported program is returned unchanged instead.
    """
    exported = export(*args, **kwargs)
    try:
        artifact = serialize(exported)
    except Exception:
        # Best effort: programs the Python serializer cannot handle skip the
        # C++ round trip entirely. Only step 2 is guarded; failures in the
        # C++ or final Python deserialization still propagate to the test.
        return exported
    else:
        cpp_program = torch._C._export.deserialize_exported_program(
            artifact.exported_program
        )
        # Swap in the C++-emitted payload; every other field of the artifact
        # is reused as produced by the Python serializer.
        artifact.exported_program = torch._C._export.serialize_exported_program(
            cpp_program
        ).encode()
        return deserialize(artifact)
|
|
|
|
|
|
def make_dynamic_cls(cls):
    """Create and register a C++-serdes variant of the test class ``cls``.

    The variant is produced by ``testing.make_test_cls_with_mocked_export``
    with these arguments:
      * ``"CppSerdes"`` as the class-name prefix.
      * ``"_cpp_serdes"`` as a tag, presumably a test-name suffix.
      * ``mocked_cpp_serdes_export`` as the export replacement.
      * ``"_expected_failure_cpp_serdes"`` as the xfail property.

    The new class is recorded in ``test_classes``, bound as a global of this
    module under its own name, and re-labelled as belonging to this module.
    Returns None.
    """
    generated_cls = testing.make_test_cls_with_mocked_export(
        cls,
        "CppSerdes",
        "_cpp_serdes",
        mocked_cpp_serdes_export,
        xfail_prop="_expected_failure_cpp_serdes",
    )
    generated_name = generated_cls.__name__

    test_classes[generated_name] = generated_cls
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    # (presumably because the test loader discovers classes by scanning this
    # module's attributes).
    globals()[generated_name] = generated_cls
    # Make the generated class report this module as its home.
    generated_cls.__module__ = __name__
|
|
|
|
|
|
# Base test classes that receive a generated C++-serdes variant.
tests = [
    test_export.TestExport,
]
for base_cls in tests:
    make_dynamic_cls(base_cls)
# Drop the loop variable. If it stayed bound at module level, the original,
# un-mocked base class would also be a global of this module. Presumably the
# test loader would then collect and run it here a second time.
del base_cls
|
|
|
|
if __name__ == "__main__":
    # Script entry point: delegate to torch._dynamo's test_case runner.
    import torch._dynamo.test_case

    torch._dynamo.test_case.run_tests()
|