mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[export] Implement cpp deserializer. (#136398)
Differential Revision: D63206258 This diff introduces a mechanism to generate a json-compatible deserializer in cpp using nlohmann json (already being used by AOTI). Why we need this? Because there will be a lot of cases where people don't want to use Python to load the graph (e.g. cpp runtime), and instead they can use this header to deserialize the JSON graph. Every time we call update_schema.py to update the schema, the header will be auto generated and included into the source files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136398 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
f98c601efe
commit
3ef2dfc1ba
1
.gitattributes
vendored
1
.gitattributes
vendored
@ -5,3 +5,4 @@
|
|||||||
.github/scripts/gql_mocks.json linguist-generated=true
|
.github/scripts/gql_mocks.json linguist-generated=true
|
||||||
third_party/LICENSES_BUNDLED.txt linguist-generated=true
|
third_party/LICENSES_BUNDLED.txt linguist-generated=true
|
||||||
tools/build/bazel/requirements.txt linguist-generated=true
|
tools/build/bazel/requirements.txt linguist-generated=true
|
||||||
|
torch/csrc/utils/generated_serialization_types.h linguist-generated=true
|
||||||
|
@ -844,6 +844,7 @@ libtorch_python_core_sources = [
|
|||||||
"torch/csrc/fx/node.cpp",
|
"torch/csrc/fx/node.cpp",
|
||||||
"torch/csrc/mps/Module.cpp",
|
"torch/csrc/mps/Module.cpp",
|
||||||
"torch/csrc/mtia/Module.cpp",
|
"torch/csrc/mtia/Module.cpp",
|
||||||
|
"torch/csrc/export/pybind.cpp",
|
||||||
"torch/csrc/inductor/aoti_package/pybind.cpp",
|
"torch/csrc/inductor/aoti_package/pybind.cpp",
|
||||||
"torch/csrc/inductor/aoti_runner/pybind.cpp",
|
"torch/csrc/inductor/aoti_runner/pybind.cpp",
|
||||||
"torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
|
"torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
|
||||||
|
@ -29,7 +29,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
commit = schema_check.update_schema()
|
commit = schema_check.update_schema()
|
||||||
|
|
||||||
if os.path.exists(args.prefix + commit.path):
|
if os.path.exists(args.prefix + commit.yaml_path):
|
||||||
if commit.result["SCHEMA_VERSION"] < commit.base["SCHEMA_VERSION"]:
|
if commit.result["SCHEMA_VERSION"] < commit.base["SCHEMA_VERSION"]:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Schema version downgraded from {commit.base['SCHEMA_VERSION']} to {commit.result['SCHEMA_VERSION']}."
|
f"Schema version downgraded from {commit.base['SCHEMA_VERSION']} to {commit.result['SCHEMA_VERSION']}."
|
||||||
@ -55,17 +55,28 @@ if __name__ == "__main__":
|
|||||||
+ f"Reason: {reason}"
|
+ f"Reason: {reason}"
|
||||||
)
|
)
|
||||||
|
|
||||||
header = (
|
first_line = (
|
||||||
"# @" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
|
"@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
|
||||||
)
|
)
|
||||||
header += f"\n# checksum<<{commit.checksum_result}>>"
|
checksum = f"checksum<<{commit.checksum_result}>>"
|
||||||
payload = dump(commit.result, Dumper=Dumper, sort_keys=False)
|
yaml_header = "# " + first_line
|
||||||
|
yaml_header += "\n# " + checksum
|
||||||
|
yaml_payload = dump(commit.result, Dumper=Dumper, sort_keys=False)
|
||||||
|
|
||||||
content = header + "\n" + payload
|
cpp_header = "// " + first_line
|
||||||
|
cpp_header += "\n// " + checksum
|
||||||
|
cpp_header += "\n// clang-format off"
|
||||||
|
cpp_header += "\n" + commit.cpp_header
|
||||||
|
cpp_header += "\n// clang-format on"
|
||||||
|
cpp_header += "\n"
|
||||||
|
|
||||||
|
yaml_content = yaml_header + "\n" + yaml_payload
|
||||||
|
|
||||||
if args.dry_run:
|
if args.dry_run:
|
||||||
print(content)
|
print(yaml_content)
|
||||||
print("\nWill write the above schema to" + args.prefix + commit.path)
|
print("\nWill write the above schema to" + args.prefix + commit.yaml_path)
|
||||||
else:
|
else:
|
||||||
with open(args.prefix + commit.path, "w") as f:
|
with open(args.prefix + commit.yaml_path, "w") as f:
|
||||||
f.write(content)
|
f.write(yaml_content)
|
||||||
|
with open(args.prefix + commit.cpp_header_path, "w") as f:
|
||||||
|
f.write(cpp_header)
|
||||||
|
60
test/export/test_cpp_serdes.py
Normal file
60
test/export/test_cpp_serdes.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
# Owner(s): ["oncall: export"]
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch._export.serde.serialize import deserialize, serialize
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from . import test_export, testing
|
||||||
|
except ImportError:
|
||||||
|
import test_export # @manual=fbcode//caffe2/test:test_export-library
|
||||||
|
import testing # @manual=fbcode//caffe2/test:test_export-library
|
||||||
|
|
||||||
|
from torch.export import export
|
||||||
|
|
||||||
|
|
||||||
|
test_classes = {}
|
||||||
|
|
||||||
|
|
||||||
|
def mocked_cpp_serdes_export(*args, **kwargs):
|
||||||
|
ep = export(*args, **kwargs)
|
||||||
|
try:
|
||||||
|
payload = serialize(ep)
|
||||||
|
except Exception:
|
||||||
|
return ep
|
||||||
|
cpp_ep = torch._C._export.deserialize_exported_program(payload.exported_program)
|
||||||
|
loaded_json = torch._C._export.serialize_exported_program(cpp_ep)
|
||||||
|
payload.exported_program = loaded_json.encode()
|
||||||
|
loaded_ep = deserialize(payload)
|
||||||
|
return loaded_ep
|
||||||
|
|
||||||
|
|
||||||
|
def make_dynamic_cls(cls):
|
||||||
|
cls_prefix = "CppSerdes"
|
||||||
|
|
||||||
|
test_class = testing.make_test_cls_with_mocked_export(
|
||||||
|
cls,
|
||||||
|
cls_prefix,
|
||||||
|
"_cpp_serdes",
|
||||||
|
mocked_cpp_serdes_export,
|
||||||
|
xfail_prop="_expected_failure_cpp_serdes",
|
||||||
|
)
|
||||||
|
|
||||||
|
test_classes[test_class.__name__] = test_class
|
||||||
|
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
|
||||||
|
globals()[test_class.__name__] = test_class
|
||||||
|
test_class.__module__ = __name__
|
||||||
|
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
test_export.TestExport,
|
||||||
|
]
|
||||||
|
for test in tests:
|
||||||
|
make_dynamic_cls(test)
|
||||||
|
del test
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
||||||
|
run_tests()
|
@ -2951,6 +2951,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||||||
export(N(), inputs, dynamic_shapes=dynamic_shapes)
|
export(N(), inputs, dynamic_shapes=dynamic_shapes)
|
||||||
|
|
||||||
@testing.expectedFailureSerDer # no unbacked bindings after deserialization?
|
@testing.expectedFailureSerDer # no unbacked bindings after deserialization?
|
||||||
|
@testing.expectedFailureCppSerDes # no unbacked bindings after deserialization?
|
||||||
@testing.expectedFailureSerDerNonStrict
|
@testing.expectedFailureSerDerNonStrict
|
||||||
def test_unbacked_bindings_for_divisible_u_symint(self):
|
def test_unbacked_bindings_for_divisible_u_symint(self):
|
||||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||||
@ -3673,6 +3674,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||||||
self._test_export_same_as_eager(kw_func, args, kwargs)
|
self._test_export_same_as_eager(kw_func, args, kwargs)
|
||||||
|
|
||||||
@testing.expectedFailureSerDer # we don't save placeholder metadata
|
@testing.expectedFailureSerDer # we don't save placeholder metadata
|
||||||
|
@testing.expectedFailureCppSerDes # we don't save placeholder metadata
|
||||||
@testing.expectedFailureSerDerNonStrict
|
@testing.expectedFailureSerDerNonStrict
|
||||||
@testing.expectedFailureNonStrict
|
@testing.expectedFailureNonStrict
|
||||||
@testing.expectedFailureTrainingIRToRunDecompNonStrict # source_fn_stack failure
|
@testing.expectedFailureTrainingIRToRunDecompNonStrict # source_fn_stack failure
|
||||||
@ -8078,6 +8080,7 @@ def forward(self, x, y):
|
|||||||
export(f, (inputs,), dynamic_shapes=dynamic_shapes)
|
export(f, (inputs,), dynamic_shapes=dynamic_shapes)
|
||||||
|
|
||||||
@testing.expectedFailureRetraceabilityNonStrict
|
@testing.expectedFailureRetraceabilityNonStrict
|
||||||
|
@testing.expectedFailureCppSerDes # dynamic shape serialization
|
||||||
def test_disable_forced_specializations_ok(self):
|
def test_disable_forced_specializations_ok(self):
|
||||||
# check that we don't force specialization, and defer to runtime asserts
|
# check that we don't force specialization, and defer to runtime asserts
|
||||||
# with allow_complex_guards_as_runtime_asserts=True to successfully export
|
# with allow_complex_guards_as_runtime_asserts=True to successfully export
|
||||||
@ -8198,6 +8201,7 @@ def forward(self, x, y):
|
|||||||
|
|
||||||
# TODO requires_grad doesn't seem to work with serialization.
|
# TODO requires_grad doesn't seem to work with serialization.
|
||||||
@testing.expectedFailureSerDer
|
@testing.expectedFailureSerDer
|
||||||
|
@testing.expectedFailureCppSerDes
|
||||||
@testing.expectedFailureSerDerNonStrict
|
@testing.expectedFailureSerDerNonStrict
|
||||||
def test_preserve_requires_grad_placeholders(self):
|
def test_preserve_requires_grad_placeholders(self):
|
||||||
class Module(torch.nn.Module):
|
class Module(torch.nn.Module):
|
||||||
@ -8536,6 +8540,7 @@ def forward(self, x, y):
|
|||||||
ep.graph_module.code
|
ep.graph_module.code
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@testing.expectedFailureCppSerDes
|
||||||
def test_slice_with_floordiv(self):
|
def test_slice_with_floordiv(self):
|
||||||
# slice operation emits runtime assert s0//2 <= s1
|
# slice operation emits runtime assert s0//2 <= s1
|
||||||
class M1(torch.nn.Module):
|
class M1(torch.nn.Module):
|
||||||
@ -9105,6 +9110,7 @@ def forward(self, x):
|
|||||||
_load_dynamic_shapes(spec, from_dict=True)
|
_load_dynamic_shapes(spec, from_dict=True)
|
||||||
|
|
||||||
@testing.expectedFailureSerDer # TODO(pianpwk): PowByNatural valuerange deserialization
|
@testing.expectedFailureSerDer # TODO(pianpwk): PowByNatural valuerange deserialization
|
||||||
|
@testing.expectedFailureCppSerDes # TODO(pianpwk): PowByNatural valuerange deserialization
|
||||||
@testing.expectedFailureSerDerNonStrict
|
@testing.expectedFailureSerDerNonStrict
|
||||||
@testing.expectedFailureRetraceabilityNonStrict
|
@testing.expectedFailureRetraceabilityNonStrict
|
||||||
def test_dim_dynamic(self):
|
def test_dim_dynamic(self):
|
||||||
|
@ -106,11 +106,13 @@ Example(s):
|
|||||||
commit = _Commit(
|
commit = _Commit(
|
||||||
result=src,
|
result=src,
|
||||||
checksum_result="",
|
checksum_result="",
|
||||||
path="",
|
yaml_path="",
|
||||||
additions=additions,
|
additions=additions,
|
||||||
subtractions=subtractions,
|
subtractions=subtractions,
|
||||||
base=dst,
|
base=dst,
|
||||||
checksum_base="",
|
checksum_base="",
|
||||||
|
cpp_header="",
|
||||||
|
cpp_header_path="",
|
||||||
)
|
)
|
||||||
next_version, _ = check(commit)
|
next_version, _ = check(commit)
|
||||||
self.assertEqual(next_version, [4, 1])
|
self.assertEqual(next_version, [4, 1])
|
||||||
@ -138,11 +140,13 @@ Example(s):
|
|||||||
commit = _Commit(
|
commit = _Commit(
|
||||||
result=src,
|
result=src,
|
||||||
checksum_result="",
|
checksum_result="",
|
||||||
path="",
|
yaml_path="",
|
||||||
additions=additions,
|
additions=additions,
|
||||||
subtractions=subtractions,
|
subtractions=subtractions,
|
||||||
base=dst,
|
base=dst,
|
||||||
checksum_base="",
|
checksum_base="",
|
||||||
|
cpp_header="",
|
||||||
|
cpp_header_path="",
|
||||||
)
|
)
|
||||||
next_version, _ = check(commit)
|
next_version, _ = check(commit)
|
||||||
self.assertEqual(next_version, [4, 1])
|
self.assertEqual(next_version, [4, 1])
|
||||||
@ -173,11 +177,13 @@ Example(s):
|
|||||||
commit = _Commit(
|
commit = _Commit(
|
||||||
result=src,
|
result=src,
|
||||||
checksum_result="",
|
checksum_result="",
|
||||||
path="",
|
yaml_path="",
|
||||||
additions=additions,
|
additions=additions,
|
||||||
subtractions=subtractions,
|
subtractions=subtractions,
|
||||||
base=dst,
|
base=dst,
|
||||||
checksum_base="",
|
checksum_base="",
|
||||||
|
cpp_header="",
|
||||||
|
cpp_header_path="",
|
||||||
)
|
)
|
||||||
next_version, _ = check(commit)
|
next_version, _ = check(commit)
|
||||||
self.assertEqual(next_version, [3, 3])
|
self.assertEqual(next_version, [3, 3])
|
||||||
@ -231,11 +237,13 @@ Example(s):
|
|||||||
commit = _Commit(
|
commit = _Commit(
|
||||||
result=src,
|
result=src,
|
||||||
checksum_result="",
|
checksum_result="",
|
||||||
path="",
|
yaml_path="",
|
||||||
additions=additions,
|
additions=additions,
|
||||||
subtractions=subtractions,
|
subtractions=subtractions,
|
||||||
base=dst,
|
base=dst,
|
||||||
checksum_base="",
|
checksum_base="",
|
||||||
|
cpp_header="",
|
||||||
|
cpp_header_path="",
|
||||||
)
|
)
|
||||||
next_version, _ = check(commit)
|
next_version, _ = check(commit)
|
||||||
self.assertEqual(next_version, [3, 3])
|
self.assertEqual(next_version, [3, 3])
|
||||||
@ -259,11 +267,13 @@ Example(s):
|
|||||||
commit = _Commit(
|
commit = _Commit(
|
||||||
result=src,
|
result=src,
|
||||||
checksum_result="",
|
checksum_result="",
|
||||||
path="",
|
yaml_path="",
|
||||||
additions=additions,
|
additions=additions,
|
||||||
subtractions=subtractions,
|
subtractions=subtractions,
|
||||||
base=dst,
|
base=dst,
|
||||||
checksum_base="",
|
checksum_base="",
|
||||||
|
cpp_header="",
|
||||||
|
cpp_header_path="",
|
||||||
)
|
)
|
||||||
next_version, _ = check(commit)
|
next_version, _ = check(commit)
|
||||||
self.assertEqual(next_version, [3, 3])
|
self.assertEqual(next_version, [3, 3])
|
||||||
@ -294,11 +304,13 @@ Example(s):
|
|||||||
commit = _Commit(
|
commit = _Commit(
|
||||||
result=src,
|
result=src,
|
||||||
checksum_result="",
|
checksum_result="",
|
||||||
path="",
|
yaml_path="",
|
||||||
additions=additions,
|
additions=additions,
|
||||||
subtractions=subtractions,
|
subtractions=subtractions,
|
||||||
base=dst,
|
base=dst,
|
||||||
checksum_base="",
|
checksum_base="",
|
||||||
|
cpp_header="",
|
||||||
|
cpp_header_path="",
|
||||||
)
|
)
|
||||||
next_version, _ = check(commit)
|
next_version, _ = check(commit)
|
||||||
self.assertEqual(next_version, [3, 3])
|
self.assertEqual(next_version, [3, 3])
|
||||||
@ -326,11 +338,13 @@ Example(s):
|
|||||||
commit = _Commit(
|
commit = _Commit(
|
||||||
result=src,
|
result=src,
|
||||||
checksum_result="",
|
checksum_result="",
|
||||||
path="",
|
yaml_path="",
|
||||||
additions=additions,
|
additions=additions,
|
||||||
subtractions=subtractions,
|
subtractions=subtractions,
|
||||||
base=dst,
|
base=dst,
|
||||||
checksum_base="",
|
checksum_base="",
|
||||||
|
cpp_header="",
|
||||||
|
cpp_header_path="",
|
||||||
)
|
)
|
||||||
next_version, _ = check(commit)
|
next_version, _ = check(commit)
|
||||||
self.assertEqual(next_version, [4, 1])
|
self.assertEqual(next_version, [4, 1])
|
||||||
|
@ -284,3 +284,8 @@ def expectedFailureSerDerPreDispatch(fn):
|
|||||||
def expectedFailurePreDispatchRunDecomp(fn):
|
def expectedFailurePreDispatchRunDecomp(fn):
|
||||||
fn._expected_failure_pre_dispatch = True
|
fn._expected_failure_pre_dispatch = True
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def expectedFailureCppSerDes(fn):
|
||||||
|
fn._expected_failure_cpp_serdes = True
|
||||||
|
return fn
|
||||||
|
@ -64,6 +64,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
|
|||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
_aoti,
|
_aoti,
|
||||||
|
_export,
|
||||||
_cpu,
|
_cpu,
|
||||||
_dynamo,
|
_dynamo,
|
||||||
_functorch,
|
_functorch,
|
||||||
|
10
torch/_C/_export.pyi
Normal file
10
torch/_C/_export.pyi
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# Defined in torch/csrc/export/pybind.cpp
|
||||||
|
|
||||||
|
class CppExportedProgram: ...
|
||||||
|
|
||||||
|
def deserialize_exported_program(
|
||||||
|
serialized_program: str,
|
||||||
|
) -> CppExportedProgram: ...
|
||||||
|
def serialize_exported_program(
|
||||||
|
cpp_exported_program: CppExportedProgram,
|
||||||
|
) -> str: ...
|
@ -1,10 +1,11 @@
|
|||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import inspect
|
||||||
import re
|
import re
|
||||||
import typing
|
import typing
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from torch._export.serde import schema
|
from torch._export.serde import schema
|
||||||
from torch._export.serde.union import _Union
|
from torch._export.serde.union import _Union
|
||||||
@ -20,43 +21,84 @@ def _check(x, msg):
|
|||||||
|
|
||||||
|
|
||||||
def _staged_schema():
|
def _staged_schema():
|
||||||
ret: Dict[str, Any] = {}
|
yaml_ret: Dict[str, Any] = {}
|
||||||
defs = {}
|
defs = {}
|
||||||
|
cpp_enum_defs: Dict[str, str] = {}
|
||||||
|
cpp_class_defs: Dict[str, str] = {}
|
||||||
|
cpp_type_decls: List[str] = []
|
||||||
|
cpp_json_defs: List[str] = []
|
||||||
|
|
||||||
def _handle_aggregate(ty):
|
def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
def dump_type(t):
|
def dump_type(t) -> Tuple[str, str]:
|
||||||
|
TYPE_MAP = {
|
||||||
|
str: "std::string",
|
||||||
|
int: "int64_t",
|
||||||
|
float: "double",
|
||||||
|
bool: "bool",
|
||||||
|
}
|
||||||
if isinstance(t, type):
|
if isinstance(t, type):
|
||||||
return t.__name__
|
if t.__name__ in cpp_enum_defs:
|
||||||
|
return t.__name__, "int64_t"
|
||||||
|
else:
|
||||||
|
return t.__name__, TYPE_MAP.get(t, t.__name__)
|
||||||
elif isinstance(t, str):
|
elif isinstance(t, str):
|
||||||
assert t in defs
|
assert t in defs
|
||||||
return t
|
assert t not in cpp_enum_defs
|
||||||
|
assert "[" not in t
|
||||||
|
return t, f"ForwardRef<{t}>"
|
||||||
elif o := typing.get_origin(t):
|
elif o := typing.get_origin(t):
|
||||||
# Lemme know if there's a better way to do this.
|
# Lemme know if there's a better way to do this.
|
||||||
if o == list:
|
if o == list:
|
||||||
head = "List"
|
yaml_head, cpp_head = "List", "std::vector"
|
||||||
elif o == dict:
|
elif o == dict:
|
||||||
head = "Dict"
|
yaml_head, cpp_head = "Dict", "std::unordered_map"
|
||||||
elif o == tuple:
|
elif o == tuple:
|
||||||
if typing.get_args(t) == ():
|
if typing.get_args(t) == ():
|
||||||
return "Tuple[()]"
|
return "Tuple[()]", "std::tuple<>"
|
||||||
head = "Tuple"
|
yaml_head, cpp_head = "Tuple", "std::tuple"
|
||||||
elif o == Union:
|
elif o == Union:
|
||||||
args = typing.get_args(t)
|
args = typing.get_args(t)
|
||||||
assert len(args) == 2 and args[1] == type(None)
|
assert len(args) == 2 and args[1] == type(None)
|
||||||
return f"Optional[{dump_type(args[0])}]"
|
yaml_type, cpp_type = dump_type(args[0])
|
||||||
|
return f"Optional[{yaml_type}]", f"std::optional<{cpp_type}>"
|
||||||
else:
|
else:
|
||||||
raise AssertionError(f"Type {t} is not supported in export schema.")
|
raise AssertionError(f"Type {t} is not supported in export schema.")
|
||||||
return (
|
yaml_arg_types, cpp_arg_types = zip(
|
||||||
f"{head}[{', '.join([dump_type(x) for x in typing.get_args(t)])}]"
|
*[dump_type(x) for x in typing.get_args(t)]
|
||||||
|
)
|
||||||
|
return (f"{yaml_head}[{', '.join(yaml_arg_types)}]"), (
|
||||||
|
f"{cpp_head}<{', '.join(cpp_arg_types)}>"
|
||||||
)
|
)
|
||||||
elif t == ():
|
elif t == ():
|
||||||
return "()"
|
return "()", ""
|
||||||
else:
|
else:
|
||||||
raise AssertionError(f"Type {t} is not supported in export schema.")
|
raise AssertionError(f"Type {t} is not supported in export schema.")
|
||||||
|
|
||||||
def dump_field(f):
|
def dump_cpp_value(v) -> str:
|
||||||
t = dump_type(f.type)
|
if v is None:
|
||||||
|
return "std::nullopt"
|
||||||
|
elif v is True:
|
||||||
|
return "true"
|
||||||
|
elif v is False:
|
||||||
|
return "false"
|
||||||
|
elif v == {}:
|
||||||
|
return "{}"
|
||||||
|
elif v == []:
|
||||||
|
return "{}"
|
||||||
|
elif v == ():
|
||||||
|
return "{}"
|
||||||
|
elif isinstance(v, str):
|
||||||
|
return f'"{v}"'
|
||||||
|
else:
|
||||||
|
raise AssertionError(
|
||||||
|
f"Default value {v} is not supported yet in export schema."
|
||||||
|
)
|
||||||
|
|
||||||
|
def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str]]:
|
||||||
|
t, cpp = dump_type(f.type)
|
||||||
ret = {"type": t}
|
ret = {"type": t}
|
||||||
|
cpp_type = cpp
|
||||||
|
cpp_default: Optional[str] = None
|
||||||
|
|
||||||
value = dataclasses.MISSING
|
value = dataclasses.MISSING
|
||||||
if f.default is not dataclasses.MISSING:
|
if f.default is not dataclasses.MISSING:
|
||||||
@ -67,24 +109,149 @@ def _staged_schema():
|
|||||||
if value is not dataclasses.MISSING:
|
if value is not dataclasses.MISSING:
|
||||||
default = str(value)
|
default = str(value)
|
||||||
ret["default"] = default
|
ret["default"] = default
|
||||||
|
cpp_default = dump_cpp_value(value)
|
||||||
|
|
||||||
if t.startswith("Optional[") and value is not None:
|
if t.startswith("Optional[") and value is not None:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
|
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
|
||||||
)
|
)
|
||||||
|
|
||||||
return ret
|
return ret, cpp_type, cpp_default
|
||||||
|
|
||||||
return {f.name: dump_field(f) for f in dataclasses.fields(ty)}
|
yaml_ret = {}
|
||||||
|
cpp_ret = {}
|
||||||
|
for f in dataclasses.fields(ty):
|
||||||
|
yaml_res, cpp_type, cpp_default = dump_field(f)
|
||||||
|
yaml_ret[f.name] = yaml_res
|
||||||
|
cpp_ret[f.name] = {"cpp_type": cpp_type, "cpp_default": cpp_default}
|
||||||
|
return yaml_ret, cpp_ret
|
||||||
|
|
||||||
def _handle_int_enum(name, ty):
|
def _handle_int_enum(name, ty):
|
||||||
ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}
|
yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}
|
||||||
|
cpp_enum_defs[
|
||||||
|
name
|
||||||
|
] = f"""
|
||||||
|
enum class {name} {{
|
||||||
|
{chr(10).join([f" {x.name} = {x.value}," for x in ty])}
|
||||||
|
}};
|
||||||
|
"""
|
||||||
|
|
||||||
def _handle_struct(name, ty):
|
def _handle_struct(name, ty):
|
||||||
ret[name] = {"kind": "struct", "fields": _handle_aggregate(ty)}
|
fields, cpp_fields = _handle_aggregate(ty)
|
||||||
|
yaml_ret[name] = {"kind": "struct", "fields": fields}
|
||||||
|
field_decls = "\n".join(
|
||||||
|
f" {f['cpp_type']} {name}{' = ' + f['cpp_default'] if f['cpp_default'] is not None else ''};"
|
||||||
|
for name, f in cpp_fields.items()
|
||||||
|
)
|
||||||
|
|
||||||
|
def accessor(name, ty):
|
||||||
|
type_name = fields[name]["type"]
|
||||||
|
if type_name in cpp_enum_defs:
|
||||||
|
return f"""
|
||||||
|
{type_name} get_{name}() const {{
|
||||||
|
return static_cast<{type_name}>({name});
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
return f"""
|
||||||
|
const {ty}& get_{name}() const {{
|
||||||
|
return {name};
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
to_json_decl = f"void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t)"
|
||||||
|
to_json_def = f"""{{
|
||||||
|
{chr(10).join([f' nlohmann_json_j["{name}"] = nlohmann_json_t.{name};' for name, f in cpp_fields.items()])}
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
from_json_decl = f"void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t)"
|
||||||
|
|
||||||
|
from_json_def = f"""{{
|
||||||
|
{name} nlohmann_json_default_obj;
|
||||||
|
{chr(10).join(
|
||||||
|
[f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});'
|
||||||
|
for name, f in cpp_fields.items()])}
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
cpp_class_defs[
|
||||||
|
name
|
||||||
|
] = f"""
|
||||||
|
class {name} {{
|
||||||
|
private:
|
||||||
|
{field_decls}
|
||||||
|
|
||||||
|
public:
|
||||||
|
{"".join([accessor(name, f["cpp_type"]) for name, f in cpp_fields.items()])}
|
||||||
|
friend {to_json_decl};
|
||||||
|
friend {from_json_decl};
|
||||||
|
}};
|
||||||
|
"""
|
||||||
|
cpp_json_defs.append(f"inline {to_json_decl} {to_json_def}")
|
||||||
|
cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}")
|
||||||
|
cpp_type_decls.append(f"class {name};")
|
||||||
|
|
||||||
def _handle_union(name, ty):
|
def _handle_union(name, ty):
|
||||||
ret[name] = {"kind": "union", "fields": _handle_aggregate(ty)}
|
fields, cpp_fields = _handle_aggregate(ty)
|
||||||
|
yaml_ret[name] = {"kind": "union", "fields": fields}
|
||||||
|
|
||||||
|
def accessor(name, ty, idx):
|
||||||
|
return f"""
|
||||||
|
const {ty}& get_{name}() const {{
|
||||||
|
return std::get<{idx + 1}>(variant_);
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
to_json_branches = "".join(
|
||||||
|
[
|
||||||
|
f"""
|
||||||
|
if (nlohmann_json_t.tag_ == Tag::{name.upper()}) {{
|
||||||
|
nlohmann_json_j["{name}"] = nlohmann_json_t.get_{name}();
|
||||||
|
return;
|
||||||
|
}}"""
|
||||||
|
for idx, (name, f) in enumerate(cpp_fields.items())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
from_json_branches = "".join(
|
||||||
|
[
|
||||||
|
f"""
|
||||||
|
if (nlohmann_json_j.contains("{name}")) {{
|
||||||
|
nlohmann_json_t.variant_.emplace<{idx + 1}>(nlohmann_json_j.at("{name}").template get<{f["cpp_type"]}>());
|
||||||
|
nlohmann_json_t.tag_ = Tag::{name.upper()};
|
||||||
|
return;
|
||||||
|
}}"""
|
||||||
|
for idx, (name, f) in enumerate(cpp_fields.items())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
cpp_class_defs[
|
||||||
|
name
|
||||||
|
] = f"""
|
||||||
|
class {name} {{
|
||||||
|
struct Void {{}};
|
||||||
|
|
||||||
|
public:
|
||||||
|
enum class Tag {{
|
||||||
|
{", ".join([name.upper() for name in cpp_fields])}
|
||||||
|
}};
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::variant<Void, {", ".join(f["cpp_type"] for f in cpp_fields.values())}> variant_;
|
||||||
|
Tag tag_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Tag tag() const {{
|
||||||
|
return tag_;
|
||||||
|
}}
|
||||||
|
{"".join([accessor(name, f["cpp_type"], idx) for idx, (name, f) in enumerate(cpp_fields.items())])}
|
||||||
|
friend void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t) {{
|
||||||
|
{to_json_branches}
|
||||||
|
}}
|
||||||
|
|
||||||
|
friend void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t) {{
|
||||||
|
{from_json_branches}
|
||||||
|
}}
|
||||||
|
}};
|
||||||
|
"""
|
||||||
|
cpp_type_decls.append(f"class {name};")
|
||||||
|
|
||||||
for name in dir(schema):
|
for name in dir(schema):
|
||||||
if name.startswith("_"):
|
if name.startswith("_"):
|
||||||
@ -97,11 +264,13 @@ def _staged_schema():
|
|||||||
|
|
||||||
defs[name] = value
|
defs[name] = value
|
||||||
|
|
||||||
|
class_ordering = {}
|
||||||
for name, value in defs.items():
|
for name, value in defs.items():
|
||||||
if isinstance(value, type):
|
if isinstance(value, type):
|
||||||
if issubclass(value, IntEnum):
|
if issubclass(value, IntEnum):
|
||||||
_handle_int_enum(name, value)
|
_handle_int_enum(name, value)
|
||||||
elif dataclasses.is_dataclass(value):
|
elif dataclasses.is_dataclass(value):
|
||||||
|
class_ordering[name] = inspect.findsource(value)[1]
|
||||||
if issubclass(value, _Union):
|
if issubclass(value, _Union):
|
||||||
_handle_union(name, value)
|
_handle_union(name, value)
|
||||||
else:
|
else:
|
||||||
@ -113,11 +282,103 @@ def _staged_schema():
|
|||||||
else:
|
else:
|
||||||
raise AssertionError(f"Unknown variable {name}: {value}")
|
raise AssertionError(f"Unknown variable {name}: {value}")
|
||||||
|
|
||||||
ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"])
|
yaml_ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"])
|
||||||
assert all(x > 0 for x in ret["SCHEMA_VERSION"])
|
assert all(x > 0 for x in yaml_ret["SCHEMA_VERSION"])
|
||||||
ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"]
|
yaml_ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"]
|
||||||
assert ret["TREESPEC_VERSION"] > 0
|
assert yaml_ret["TREESPEC_VERSION"] > 0
|
||||||
return ret
|
|
||||||
|
cpp_header = f"""
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <variant>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN
|
||||||
|
#define NLOHMANN_JSON_NAMESPACE_BEGIN namespace nlohmann {{
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef NLOHMANN_JSON_NAMESPACE_END
|
||||||
|
#define NLOHMANN_JSON_NAMESPACE_END }}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// https://github.com/nlohmann/json/pull/2117
|
||||||
|
NLOHMANN_JSON_NAMESPACE_BEGIN
|
||||||
|
template <typename T>
|
||||||
|
struct adl_serializer<std::optional<T>> {{
|
||||||
|
static void to_json(json& j, const std::optional<T>& opt) {{
|
||||||
|
if (opt == std::nullopt) {{
|
||||||
|
j = nullptr;
|
||||||
|
}} else {{
|
||||||
|
j = *opt; // this will call adl_serializer<T>::to_json which will
|
||||||
|
// find the free function to_json in T's namespace!
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
|
||||||
|
static void from_json(const json& j, std::optional<T>& opt) {{
|
||||||
|
if (j.is_null()) {{
|
||||||
|
opt = std::nullopt;
|
||||||
|
}} else {{
|
||||||
|
opt = j.template get<T>(); // same as above, but with
|
||||||
|
// adl_serializer<T>::from_json
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
}};
|
||||||
|
NLOHMANN_JSON_NAMESPACE_END
|
||||||
|
|
||||||
|
namespace torch {{
|
||||||
|
namespace _export {{
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class ForwardRef {{
|
||||||
|
static_assert(!std::is_reference_v<T>, "ForwardRef cannot be a reference type");
|
||||||
|
|
||||||
|
public:
|
||||||
|
ForwardRef(): ptr_(std::make_unique<T>()) {{}}
|
||||||
|
ForwardRef(ForwardRef<T>&&) = default;
|
||||||
|
ForwardRef(const ForwardRef<T>& other): ptr_(std::make_unique<T>(*other.ptr_)) {{}}
|
||||||
|
ForwardRef<T>& operator=(ForwardRef<T>&&) = default;
|
||||||
|
ForwardRef<T>& operator=(const ForwardRef<T>& other) {{
|
||||||
|
ptr_ = std::make_unique<T>(*other.ptr_);
|
||||||
|
}}
|
||||||
|
const T& operator*() const {{
|
||||||
|
return *ptr_;
|
||||||
|
}}
|
||||||
|
|
||||||
|
const T* operator->() const {{
|
||||||
|
return ptr_.get();
|
||||||
|
}}
|
||||||
|
|
||||||
|
void emplace(T&& t) {{
|
||||||
|
ptr_ = std::make_unique<T>(std::move(t));
|
||||||
|
}}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<T> ptr_;
|
||||||
|
}};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void to_json(nlohmann::json& j, const ForwardRef<T>& p) {{
|
||||||
|
j = *p;
|
||||||
|
}}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void from_json(const nlohmann::json& j, ForwardRef<T>& p) {{
|
||||||
|
p.emplace(j.template get<T>());
|
||||||
|
}}
|
||||||
|
|
||||||
|
{chr(10).join(cpp_type_decls)}
|
||||||
|
{"".join(cpp_enum_defs.values())}
|
||||||
|
{"".join(dict(sorted(cpp_class_defs.items(), key=lambda x: class_ordering[x[0]])).values())}
|
||||||
|
{chr(10).join(cpp_json_defs)}
|
||||||
|
}} // namespace _export
|
||||||
|
}} // namespace torch
|
||||||
|
"""
|
||||||
|
return yaml_ret, cpp_header
|
||||||
|
|
||||||
|
|
||||||
def _diff_schema(dst, src):
|
def _diff_schema(dst, src):
|
||||||
@ -193,11 +454,13 @@ def _hash_schema(s):
|
|||||||
class _Commit:
|
class _Commit:
|
||||||
result: Dict[str, Any]
|
result: Dict[str, Any]
|
||||||
checksum_result: str
|
checksum_result: str
|
||||||
path: str
|
yaml_path: str
|
||||||
additions: Dict[str, Any]
|
additions: Dict[str, Any]
|
||||||
subtractions: Dict[str, Any]
|
subtractions: Dict[str, Any]
|
||||||
base: Dict[str, Any]
|
base: Dict[str, Any]
|
||||||
checksum_base: Optional[str]
|
checksum_base: Optional[str]
|
||||||
|
cpp_header: str
|
||||||
|
cpp_header_path: str
|
||||||
|
|
||||||
|
|
||||||
def update_schema():
|
def update_schema():
|
||||||
@ -217,16 +480,22 @@ def update_schema():
|
|||||||
checksum_base = None
|
checksum_base = None
|
||||||
dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
|
dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
|
||||||
|
|
||||||
src = _staged_schema()
|
src, cpp_header = _staged_schema()
|
||||||
additions, subtractions = _diff_schema(dst, src)
|
additions, subtractions = _diff_schema(dst, src)
|
||||||
|
yaml_path = __package__.replace(".", "/") + "/schema.yaml"
|
||||||
|
torch_prefix = "torch/"
|
||||||
|
assert yaml_path.startswith(torch_prefix) # sanity check
|
||||||
|
|
||||||
return _Commit(
|
return _Commit(
|
||||||
result=src,
|
result=src,
|
||||||
checksum_result=_hash_schema(src),
|
checksum_result=_hash_schema(src),
|
||||||
path=__package__.replace(".", "/") + "/schema.yaml",
|
yaml_path=yaml_path,
|
||||||
additions=additions,
|
additions=additions,
|
||||||
subtractions=subtractions,
|
subtractions=subtractions,
|
||||||
base=dst,
|
base=dst,
|
||||||
checksum_base=checksum_base,
|
checksum_base=checksum_base,
|
||||||
|
cpp_header=cpp_header,
|
||||||
|
cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -697,7 +697,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def is_sym_int_arg(self, arg) -> bool:
|
def is_sym_int_arg(self, arg) -> bool:
|
||||||
return isinstance(arg, int) or (
|
return type(arg) is int or (
|
||||||
isinstance(arg, torch.fx.Node)
|
isinstance(arg, torch.fx.Node)
|
||||||
and arg.name in self.graph_state.sym_int_values
|
and arg.name in self.graph_state.sym_int_values
|
||||||
)
|
)
|
||||||
@ -770,13 +770,13 @@ class GraphModuleSerializer(metaclass=Final):
|
|||||||
# For regular FX graph, SymInt arg should be a fx.Node with
|
# For regular FX graph, SymInt arg should be a fx.Node with
|
||||||
# self.is_sym_int_arg(arg) being true
|
# self.is_sym_int_arg(arg) being true
|
||||||
return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg)))
|
return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg)))
|
||||||
elif isinstance(arg, bool):
|
elif type(arg) is bool:
|
||||||
return Argument.create(as_bool=arg)
|
return Argument.create(as_bool=arg)
|
||||||
elif isinstance(arg, str):
|
elif type(arg) is str:
|
||||||
return Argument.create(as_string=arg)
|
return Argument.create(as_string=arg)
|
||||||
elif isinstance(arg, int):
|
elif type(arg) is int:
|
||||||
return Argument.create(as_int=arg)
|
return Argument.create(as_int=arg)
|
||||||
elif isinstance(arg, float):
|
elif type(arg) is float:
|
||||||
return Argument.create(as_float=arg)
|
return Argument.create(as_float=arg)
|
||||||
elif arg is None:
|
elif arg is None:
|
||||||
return Argument.create(as_none=())
|
return Argument.create(as_none=())
|
||||||
@ -814,14 +814,13 @@ class GraphModuleSerializer(metaclass=Final):
|
|||||||
)
|
)
|
||||||
return Argument.create(as_tensors=[])
|
return Argument.create(as_tensors=[])
|
||||||
|
|
||||||
# Must check bool first, as bool is also treated as int
|
if all(type(a) is bool for a in arg):
|
||||||
if all(isinstance(a, bool) for a in arg):
|
|
||||||
return Argument.create(as_bools=list(arg))
|
return Argument.create(as_bools=list(arg))
|
||||||
elif all(isinstance(a, int) for a in arg):
|
elif all(type(a) is int for a in arg):
|
||||||
return Argument.create(as_ints=list(arg))
|
return Argument.create(as_ints=list(arg))
|
||||||
elif all(isinstance(a, float) for a in arg):
|
elif all(type(a) is float for a in arg):
|
||||||
return Argument.create(as_floats=list(arg))
|
return Argument.create(as_floats=list(arg))
|
||||||
elif all(isinstance(a, str) for a in arg):
|
elif all(type(a) is str for a in arg):
|
||||||
return Argument.create(as_strings=list(arg))
|
return Argument.create(as_strings=list(arg))
|
||||||
elif all(isinstance(a, torch.SymInt) for a in arg):
|
elif all(isinstance(a, torch.SymInt) for a in arg):
|
||||||
# This is a special branch for handling SymInt args in inductor's
|
# This is a special branch for handling SymInt args in inductor's
|
||||||
@ -837,7 +836,7 @@ class GraphModuleSerializer(metaclass=Final):
|
|||||||
for a in arg:
|
for a in arg:
|
||||||
if isinstance(a, torch.fx.Node):
|
if isinstance(a, torch.fx.Node):
|
||||||
values.append(SymIntArgument.create(as_name=a.name))
|
values.append(SymIntArgument.create(as_name=a.name))
|
||||||
elif isinstance(a, int):
|
elif type(a) is int:
|
||||||
values.append(SymIntArgument.create(as_int=a))
|
values.append(SymIntArgument.create(as_int=a))
|
||||||
return Argument.create(as_sym_ints=values)
|
return Argument.create(as_sym_ints=values)
|
||||||
elif all(self.is_sym_bool_arg(a) for a in arg):
|
elif all(self.is_sym_bool_arg(a) for a in arg):
|
||||||
@ -952,13 +951,13 @@ class GraphModuleSerializer(metaclass=Final):
|
|||||||
def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
|
def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
|
||||||
if spec.kind == ep.InputKind.USER_INPUT:
|
if spec.kind == ep.InputKind.USER_INPUT:
|
||||||
if isinstance(spec.arg, ep.ConstantArgument):
|
if isinstance(spec.arg, ep.ConstantArgument):
|
||||||
if isinstance(spec.arg.value, int):
|
if type(spec.arg.value) is int:
|
||||||
constant_spec = ConstantValue.create(as_int=spec.arg.value)
|
constant_spec = ConstantValue.create(as_int=spec.arg.value)
|
||||||
elif isinstance(spec.arg.value, bool):
|
elif type(spec.arg.value) is bool:
|
||||||
constant_spec = ConstantValue.create(as_bool=spec.arg.value)
|
constant_spec = ConstantValue.create(as_bool=spec.arg.value)
|
||||||
elif isinstance(spec.arg.value, str):
|
elif type(spec.arg.value) is str:
|
||||||
constant_spec = ConstantValue.create(as_string=spec.arg.value)
|
constant_spec = ConstantValue.create(as_string=spec.arg.value)
|
||||||
elif isinstance(spec.arg.value, float):
|
elif type(spec.arg.value) is float:
|
||||||
constant_spec = ConstantValue.create(as_float=spec.arg.value)
|
constant_spec = ConstantValue.create(as_float=spec.arg.value)
|
||||||
elif spec.arg.value is None:
|
elif spec.arg.value is None:
|
||||||
constant_spec = ConstantValue.create(as_none=())
|
constant_spec = ConstantValue.create(as_none=())
|
||||||
@ -1548,7 +1547,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||||||
|
|
||||||
return self.shape_env.create_symintnode(sym, hint=hint)
|
return self.shape_env.create_symintnode(sym, hint=hint)
|
||||||
elif s.type == "as_int":
|
elif s.type == "as_int":
|
||||||
assert isinstance(val, int)
|
assert type(val) is int
|
||||||
return val
|
return val
|
||||||
else:
|
else:
|
||||||
raise SerializeError(
|
raise SerializeError(
|
||||||
|
@ -68,6 +68,7 @@
|
|||||||
#include <torch/csrc/autograd/python_variable.h>
|
#include <torch/csrc/autograd/python_variable.h>
|
||||||
#include <torch/csrc/cpu/Module.h>
|
#include <torch/csrc/cpu/Module.h>
|
||||||
#include <torch/csrc/dynamo/init.h>
|
#include <torch/csrc/dynamo/init.h>
|
||||||
|
#include <torch/csrc/export/pybind.h>
|
||||||
#include <torch/csrc/functorch/init.h>
|
#include <torch/csrc/functorch/init.h>
|
||||||
#include <torch/csrc/fx/node.h>
|
#include <torch/csrc/fx/node.h>
|
||||||
#include <torch/csrc/inductor/aoti_package/pybind.h>
|
#include <torch/csrc/inductor/aoti_package/pybind.h>
|
||||||
@ -1773,6 +1774,7 @@ PyObject* initModule() {
|
|||||||
torch::profiler::initPythonBindings(module);
|
torch::profiler::initPythonBindings(module);
|
||||||
torch::python::init_bindings(module);
|
torch::python::init_bindings(module);
|
||||||
torch::lazy::initLazyBindings(module);
|
torch::lazy::initLazyBindings(module);
|
||||||
|
torch::_export::initExportBindings(module);
|
||||||
torch::inductor::initAOTIRunnerBindings(module);
|
torch::inductor::initAOTIRunnerBindings(module);
|
||||||
torch::inductor::initAOTIPackageBindings(module);
|
torch::inductor::initAOTIPackageBindings(module);
|
||||||
#ifdef USE_ITT
|
#ifdef USE_ITT
|
||||||
|
20
torch/csrc/export/pybind.cpp
Normal file
20
torch/csrc/export/pybind.cpp
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
#include <torch/csrc/utils/generated_serialization_types.h>
|
||||||
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
|
namespace torch::_export {
|
||||||
|
|
||||||
|
void initExportBindings(PyObject* module) {
|
||||||
|
auto rootModule = py::handle(module).cast<py::module>();
|
||||||
|
auto m = rootModule.def_submodule("_export");
|
||||||
|
|
||||||
|
py::class_<ExportedProgram>(m, "CppExportedProgram");
|
||||||
|
|
||||||
|
m.def("deserialize_exported_program", [](const std::string& serialized) {
|
||||||
|
return nlohmann::json::parse(serialized).get<ExportedProgram>();
|
||||||
|
});
|
||||||
|
|
||||||
|
m.def("serialize_exported_program", [](const ExportedProgram& ep) {
|
||||||
|
return nlohmann::json(ep).dump();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} // namespace torch::_export
|
7
torch/csrc/export/pybind.h
Normal file
7
torch/csrc/export/pybind.h
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
#include <torch/csrc/python_headers.h>
|
||||||
|
|
||||||
|
namespace torch::_export {
|
||||||
|
|
||||||
|
void initExportBindings(PyObject* module);
|
||||||
|
|
||||||
|
} // namespace torch::_export
|
2188
torch/csrc/utils/generated_serialization_types.h
generated
Normal file
2188
torch/csrc/utils/generated_serialization_types.h
generated
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user