From 3ef2dfc1bae3c2830c87b151190842531ca52dd7 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Thu, 14 Nov 2024 16:34:56 +0000 Subject: [PATCH] [export] Implement cpp deserializer. (#136398) Differential Revision: D63206258 This diff introduces a mechanism to generate a json-compatible deserializer in cpp using nlohmann json (already being used by AOTI). Why we need this? Because there will be a lot of cases where people don't want to use Python to load the graph (e.g. cpp runtime), and instead they can use this header to deserialize the JSON graph. Every time we call update_schema.py to update the schema, the header will be auto generated and included into the source files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136398 Approved by: https://github.com/angelayi --- .gitattributes | 1 + build_variables.bzl | 1 + scripts/export/update_schema.py | 31 +- test/export/test_cpp_serdes.py | 60 + test/export/test_export.py | 6 + test/export/test_schema.py | 28 +- test/export/testing.py | 5 + torch/_C/__init__.pyi.in | 1 + torch/_C/_export.pyi | 10 + torch/_export/serde/schema_check.py | 327 ++- torch/_export/serde/serialize.py | 31 +- torch/csrc/Module.cpp | 2 + torch/csrc/export/pybind.cpp | 20 + torch/csrc/export/pybind.h | 7 + .../utils/generated_serialization_types.h | 2188 +++++++++++++++++ 15 files changed, 2656 insertions(+), 62 deletions(-) create mode 100644 test/export/test_cpp_serdes.py create mode 100644 torch/_C/_export.pyi create mode 100644 torch/csrc/export/pybind.cpp create mode 100644 torch/csrc/export/pybind.h create mode 100644 torch/csrc/utils/generated_serialization_types.h diff --git a/.gitattributes b/.gitattributes index e90430175295..f20bd8f1b31b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -5,3 +5,4 @@ .github/scripts/gql_mocks.json linguist-generated=true third_party/LICENSES_BUNDLED.txt linguist-generated=true tools/build/bazel/requirements.txt linguist-generated=true +torch/csrc/utils/generated_serialization_types.h linguist-generated=true diff --git a/build_variables.bzl b/build_variables.bzl index 9cb351a4a090..b84ed46aacb1 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -844,6 +844,7 @@ libtorch_python_core_sources = [ "torch/csrc/fx/node.cpp", "torch/csrc/mps/Module.cpp", "torch/csrc/mtia/Module.cpp", + "torch/csrc/export/pybind.cpp", "torch/csrc/inductor/aoti_package/pybind.cpp", "torch/csrc/inductor/aoti_runner/pybind.cpp", "torch/csrc/inductor/aoti_eager/kernel_holder.cpp", diff --git a/scripts/export/update_schema.py b/scripts/export/update_schema.py index 74d7dcaaec8c..3de818ef2424 100644 --- a/scripts/export/update_schema.py +++ b/scripts/export/update_schema.py @@ -29,7 +29,7 @@ if __name__ == "__main__": commit = schema_check.update_schema() - if os.path.exists(args.prefix + commit.path): + if os.path.exists(args.prefix + commit.yaml_path): if commit.result["SCHEMA_VERSION"] < commit.base["SCHEMA_VERSION"]: raise RuntimeError( f"Schema version downgraded from {commit.base['SCHEMA_VERSION']} to {commit.result['SCHEMA_VERSION']}." @@ -55,17 +55,28 @@ if __name__ == "__main__": + f"Reason: {reason}" ) - header = ( - "# @" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py" + first_line = ( + "@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py" ) - header += f"\n# checksum<<{commit.checksum_result}>>" - payload = dump(commit.result, Dumper=Dumper, sort_keys=False) + checksum = f"checksum<<{commit.checksum_result}>>" + yaml_header = "# " + first_line + yaml_header += "\n# " + checksum + yaml_payload = dump(commit.result, Dumper=Dumper, sort_keys=False) - content = header + "\n" + payload + cpp_header = "// " + first_line + cpp_header += "\n// " + checksum + cpp_header += "\n// clang-format off" + cpp_header += "\n" + commit.cpp_header + cpp_header += "\n// clang-format on" + cpp_header += "\n" + + yaml_content = yaml_header + "\n" + yaml_payload if args.dry_run: - print(content) - print("\nWill write the above schema to" + args.prefix + commit.path) + print(yaml_content) + print("\nWill write the above schema to" + args.prefix + commit.yaml_path) else: - with open(args.prefix + commit.path, "w") as f: - f.write(content) + with open(args.prefix + commit.yaml_path, "w") as f: + f.write(yaml_content) + with open(args.prefix + commit.cpp_header_path, "w") as f: + f.write(cpp_header) diff --git a/test/export/test_cpp_serdes.py b/test/export/test_cpp_serdes.py new file mode 100644 index 000000000000..7897e123404d --- /dev/null +++ b/test/export/test_cpp_serdes.py @@ -0,0 +1,60 @@ +# Owner(s): ["oncall: export"] + + +import torch +from torch._export.serde.serialize import deserialize, serialize + + +try: + from . import test_export, testing +except ImportError: + import test_export # @manual=fbcode//caffe2/test:test_export-library + import testing # @manual=fbcode//caffe2/test:test_export-library + +from torch.export import export + + +test_classes = {} + + +def mocked_cpp_serdes_export(*args, **kwargs): + ep = export(*args, **kwargs) + try: + payload = serialize(ep) + except Exception: + return ep + cpp_ep = torch._C._export.deserialize_exported_program(payload.exported_program) + loaded_json = torch._C._export.serialize_exported_program(cpp_ep) + payload.exported_program = loaded_json.encode() + loaded_ep = deserialize(payload) + return loaded_ep + + +def make_dynamic_cls(cls): + cls_prefix = "CppSerdes" + + test_class = testing.make_test_cls_with_mocked_export( + cls, + cls_prefix, + "_cpp_serdes", + mocked_cpp_serdes_export, + xfail_prop="_expected_failure_cpp_serdes", + ) + + test_classes[test_class.__name__] = test_class + # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING + globals()[test_class.__name__] = test_class + test_class.__module__ = __name__ + + +tests = [ + test_export.TestExport, +] +for test in tests: + make_dynamic_cls(test) +del test + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/export/test_export.py b/test/export/test_export.py index 880f8ad256b1..c933c748fd48 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2951,6 +2951,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): export(N(), inputs, dynamic_shapes=dynamic_shapes) @testing.expectedFailureSerDer # no unbacked bindings after deserialization? + @testing.expectedFailureCppSerDes # no unbacked bindings after deserialization? @testing.expectedFailureSerDerNonStrict def test_unbacked_bindings_for_divisible_u_symint(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: @@ -3673,6 +3674,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): self._test_export_same_as_eager(kw_func, args, kwargs) @testing.expectedFailureSerDer # we don't save placeholder metadata + @testing.expectedFailureCppSerDes # we don't save placeholder metadata @testing.expectedFailureSerDerNonStrict @testing.expectedFailureNonStrict @testing.expectedFailureTrainingIRToRunDecompNonStrict # source_fn_stack failure @@ -8078,6 +8080,7 @@ def forward(self, x, y): export(f, (inputs,), dynamic_shapes=dynamic_shapes) @testing.expectedFailureRetraceabilityNonStrict + @testing.expectedFailureCppSerDes # dynamic shape serialization def test_disable_forced_specializations_ok(self): # check that we don't force specialization, and defer to runtime asserts # with allow_complex_guards_as_runtime_asserts=True to successfully export @@ -8198,6 +8201,7 @@ def forward(self, x, y): # TODO requires_grad doesn't seem to work with serialization. @testing.expectedFailureSerDer + @testing.expectedFailureCppSerDes @testing.expectedFailureSerDerNonStrict def test_preserve_requires_grad_placeholders(self): class Module(torch.nn.Module): @@ -8536,6 +8540,7 @@ def forward(self, x, y): ep.graph_module.code ) + @testing.expectedFailureCppSerDes def test_slice_with_floordiv(self): # slice operation emits runtime assert s0//2 <= s1 class M1(torch.nn.Module): @@ -9105,6 +9110,7 @@ def forward(self, x): _load_dynamic_shapes(spec, from_dict=True) @testing.expectedFailureSerDer # TODO(pianpwk): PowByNatural valuerange deserialization + @testing.expectedFailureCppSerDes # TODO(pianpwk): PowByNatural valuerange deserialization @testing.expectedFailureSerDerNonStrict @testing.expectedFailureRetraceabilityNonStrict def test_dim_dynamic(self): diff --git a/test/export/test_schema.py b/test/export/test_schema.py index 7950f40086ac..e98c289520d0 100644 --- a/test/export/test_schema.py +++ b/test/export/test_schema.py @@ -106,11 +106,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [4, 1]) @@ -138,11 +140,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [4, 1]) @@ -173,11 +177,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) @@ -231,11 +237,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) @@ -259,11 +267,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) @@ -294,11 +304,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) @@ -326,11 +338,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [4, 1]) diff --git a/test/export/testing.py b/test/export/testing.py index ed72f219eb63..3e98ac3024a2 100644 --- a/test/export/testing.py +++ b/test/export/testing.py @@ -284,3 +284,8 @@ def expectedFailureSerDerPreDispatch(fn): def expectedFailurePreDispatchRunDecomp(fn): fn._expected_failure_pre_dispatch = True return fn + + +def expectedFailureCppSerDes(fn): + fn._expected_failure_cpp_serdes = True + return fn diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 6dd6dd3ccb4b..ce851a826063 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -64,6 +64,7 @@ from torch.utils._python_dispatch import TorchDispatchMode from . import ( _aoti, + _export, _cpu, _dynamo, _functorch, diff --git a/torch/_C/_export.pyi b/torch/_C/_export.pyi new file mode 100644 index 000000000000..5351945b9d51 --- /dev/null +++ b/torch/_C/_export.pyi @@ -0,0 +1,10 @@ +# Defined in torch/csrc/export/pybind.cpp + +class CppExportedProgram: ... + +def deserialize_exported_program( + serialized_program: str, +) -> CppExportedProgram: ... +def serialize_exported_program( + cpp_exported_program: CppExportedProgram, +) -> str: ... diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index 9116e136cf78..b71c02829ceb 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -1,10 +1,11 @@ # mypy: allow-untyped-defs import dataclasses import hashlib +import inspect import re import typing from enum import IntEnum -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from torch._export.serde import schema from torch._export.serde.union import _Union @@ -20,43 +21,84 @@ def _check(x, msg): def _staged_schema(): - ret: Dict[str, Any] = {} + yaml_ret: Dict[str, Any] = {} defs = {} + cpp_enum_defs: Dict[str, str] = {} + cpp_class_defs: Dict[str, str] = {} + cpp_type_decls: List[str] = [] + cpp_json_defs: List[str] = [] - def _handle_aggregate(ty): - def dump_type(t): + def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def dump_type(t) -> Tuple[str, str]: + TYPE_MAP = { + str: "std::string", + int: "int64_t", + float: "double", + bool: "bool", + } if isinstance(t, type): - return t.__name__ + if t.__name__ in cpp_enum_defs: + return t.__name__, "int64_t" + else: + return t.__name__, TYPE_MAP.get(t, t.__name__) elif isinstance(t, str): assert t in defs - return t + assert t not in cpp_enum_defs + assert "[" not in t + return t, f"ForwardRef<{t}>" elif o := typing.get_origin(t): # Lemme know if there's a better way to do this. if o == list: - head = "List" + yaml_head, cpp_head = "List", "std::vector" elif o == dict: - head = "Dict" + yaml_head, cpp_head = "Dict", "std::unordered_map" elif o == tuple: if typing.get_args(t) == (): - return "Tuple[()]" - head = "Tuple" + return "Tuple[()]", "std::tuple<>" + yaml_head, cpp_head = "Tuple", "std::tuple" elif o == Union: args = typing.get_args(t) assert len(args) == 2 and args[1] == type(None) - return f"Optional[{dump_type(args[0])}]" + yaml_type, cpp_type = dump_type(args[0]) + return f"Optional[{yaml_type}]", f"std::optional<{cpp_type}>" else: raise AssertionError(f"Type {t} is not supported in export schema.") - return ( - f"{head}[{', '.join([dump_type(x) for x in typing.get_args(t)])}]" + yaml_arg_types, cpp_arg_types = zip( + *[dump_type(x) for x in typing.get_args(t)] + ) + return (f"{yaml_head}[{', '.join(yaml_arg_types)}]"), ( + f"{cpp_head}<{', '.join(cpp_arg_types)}>" ) elif t == (): - return "()" + return "()", "" else: raise AssertionError(f"Type {t} is not supported in export schema.") - def dump_field(f): - t = dump_type(f.type) + def dump_cpp_value(v) -> str: + if v is None: + return "std::nullopt" + elif v is True: + return "true" + elif v is False: + return "false" + elif v == {}: + return "{}" + elif v == []: + return "{}" + elif v == (): + return "{}" + elif isinstance(v, str): + return f'"{v}"' + else: + raise AssertionError( + f"Default value {v} is not supported yet in export schema." + ) + + def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str]]: + t, cpp = dump_type(f.type) ret = {"type": t} + cpp_type = cpp + cpp_default: Optional[str] = None value = dataclasses.MISSING if f.default is not dataclasses.MISSING: @@ -67,24 +109,149 @@ def _staged_schema(): if value is not dataclasses.MISSING: default = str(value) ret["default"] = default + cpp_default = dump_cpp_value(value) if t.startswith("Optional[") and value is not None: raise AssertionError( f"Optional field {ty.__name__}.{f.name} must have default value to be None." ) - return ret + return ret, cpp_type, cpp_default - return {f.name: dump_field(f) for f in dataclasses.fields(ty)} + yaml_ret = {} + cpp_ret = {} + for f in dataclasses.fields(ty): + yaml_res, cpp_type, cpp_default = dump_field(f) + yaml_ret[f.name] = yaml_res + cpp_ret[f.name] = {"cpp_type": cpp_type, "cpp_default": cpp_default} + return yaml_ret, cpp_ret def _handle_int_enum(name, ty): - ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} + yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} + cpp_enum_defs[ + name + ] = f""" +enum class {name} {{ +{chr(10).join([f" {x.name} = {x.value}," for x in ty])} +}}; +""" def _handle_struct(name, ty): - ret[name] = {"kind": "struct", "fields": _handle_aggregate(ty)} + fields, cpp_fields = _handle_aggregate(ty) + yaml_ret[name] = {"kind": "struct", "fields": fields} + field_decls = "\n".join( + f" {f['cpp_type']} {name}{' = ' + f['cpp_default'] if f['cpp_default'] is not None else ''};" + for name, f in cpp_fields.items() + ) + + def accessor(name, ty): + type_name = fields[name]["type"] + if type_name in cpp_enum_defs: + return f""" + {type_name} get_{name}() const {{ + return static_cast<{type_name}>({name}); + }} +""" + return f""" + const {ty}& get_{name}() const {{ + return {name}; + }} +""" + + to_json_decl = f"void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t)" + to_json_def = f"""{{ +{chr(10).join([f' nlohmann_json_j["{name}"] = nlohmann_json_t.{name};' for name, f in cpp_fields.items()])} +}} +""" + from_json_decl = f"void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t)" + + from_json_def = f"""{{ + {name} nlohmann_json_default_obj; +{chr(10).join( + [f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});' + for name, f in cpp_fields.items()])} +}} +""" + cpp_class_defs[ + name + ] = f""" +class {name} {{ + private: +{field_decls} + + public: +{"".join([accessor(name, f["cpp_type"]) for name, f in cpp_fields.items()])} + friend {to_json_decl}; + friend {from_json_decl}; +}}; +""" + cpp_json_defs.append(f"inline {to_json_decl} {to_json_def}") + cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}") + cpp_type_decls.append(f"class {name};") def _handle_union(name, ty): - ret[name] = {"kind": "union", "fields": _handle_aggregate(ty)} + fields, cpp_fields = _handle_aggregate(ty) + yaml_ret[name] = {"kind": "union", "fields": fields} + + def accessor(name, ty, idx): + return f""" + const {ty}& get_{name}() const {{ + return std::get<{idx + 1}>(variant_); + }} +""" + + to_json_branches = "".join( + [ + f""" + if (nlohmann_json_t.tag_ == Tag::{name.upper()}) {{ + nlohmann_json_j["{name}"] = nlohmann_json_t.get_{name}(); + return; + }}""" + for idx, (name, f) in enumerate(cpp_fields.items()) + ] + ) + from_json_branches = "".join( + [ + f""" + if (nlohmann_json_j.contains("{name}")) {{ + nlohmann_json_t.variant_.emplace<{idx + 1}>(nlohmann_json_j.at("{name}").template get<{f["cpp_type"]}>()); + nlohmann_json_t.tag_ = Tag::{name.upper()}; + return; + }}""" + for idx, (name, f) in enumerate(cpp_fields.items()) + ] + ) + + cpp_class_defs[ + name + ] = f""" +class {name} {{ + struct Void {{}}; + + public: + enum class Tag {{ + {", ".join([name.upper() for name in cpp_fields])} + }}; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const {{ + return tag_; + }} +{"".join([accessor(name, f["cpp_type"], idx) for idx, (name, f) in enumerate(cpp_fields.items())])} + friend void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t) {{ +{to_json_branches} + }} + + friend void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t) {{ +{from_json_branches} + }} +}}; +""" + cpp_type_decls.append(f"class {name};") for name in dir(schema): if name.startswith("_"): @@ -97,11 +264,13 @@ def _staged_schema(): defs[name] = value + class_ordering = {} for name, value in defs.items(): if isinstance(value, type): if issubclass(value, IntEnum): _handle_int_enum(name, value) elif dataclasses.is_dataclass(value): + class_ordering[name] = inspect.findsource(value)[1] if issubclass(value, _Union): _handle_union(name, value) else: @@ -113,11 +282,103 @@ def _staged_schema(): else: raise AssertionError(f"Unknown variable {name}: {value}") - ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"]) - assert all(x > 0 for x in ret["SCHEMA_VERSION"]) - ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"] - assert ret["TREESPEC_VERSION"] > 0 - return ret + yaml_ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"]) + assert all(x > 0 for x in yaml_ret["SCHEMA_VERSION"]) + yaml_ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"] + assert yaml_ret["TREESPEC_VERSION"] > 0 + + cpp_header = f""" +#pragma once + +#include +#include +#include +#include +#include + +#include + +#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN +#define NLOHMANN_JSON_NAMESPACE_BEGIN namespace nlohmann {{ +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_END +#define NLOHMANN_JSON_NAMESPACE_END }} +#endif + +// https://github.com/nlohmann/json/pull/2117 +NLOHMANN_JSON_NAMESPACE_BEGIN +template +struct adl_serializer> {{ + static void to_json(json& j, const std::optional& opt) {{ + if (opt == std::nullopt) {{ + j = nullptr; + }} else {{ + j = *opt; // this will call adl_serializer::to_json which will + // find the free function to_json in T's namespace! + }} + }} + + static void from_json(const json& j, std::optional& opt) {{ + if (j.is_null()) {{ + opt = std::nullopt; + }} else {{ + opt = j.template get(); // same as above, but with + // adl_serializer::from_json + }} + }} +}}; +NLOHMANN_JSON_NAMESPACE_END + +namespace torch {{ +namespace _export {{ + +template +class ForwardRef {{ + static_assert(!std::is_reference_v, "ForwardRef cannot be a reference type"); + + public: + ForwardRef(): ptr_(std::make_unique()) {{}} + ForwardRef(ForwardRef&&) = default; + ForwardRef(const ForwardRef& other): ptr_(std::make_unique(*other.ptr_)) {{}} + ForwardRef& operator=(ForwardRef&&) = default; + ForwardRef& operator=(const ForwardRef& other) {{ + ptr_ = std::make_unique(*other.ptr_); + }} + const T& operator*() const {{ + return *ptr_; + }} + + const T* operator->() const {{ + return ptr_.get(); + }} + + void emplace(T&& t) {{ + ptr_ = std::make_unique(std::move(t)); + }} + + private: + std::unique_ptr ptr_; +}}; + +template +void to_json(nlohmann::json& j, const ForwardRef& p) {{ + j = *p; +}} + +template +void from_json(const nlohmann::json& j, ForwardRef& p) {{ + p.emplace(j.template get()); +}} + +{chr(10).join(cpp_type_decls)} +{"".join(cpp_enum_defs.values())} +{"".join(dict(sorted(cpp_class_defs.items(), key=lambda x: class_ordering[x[0]])).values())} +{chr(10).join(cpp_json_defs)} +}} // namespace _export +}} // namespace torch +""" + return yaml_ret, cpp_header def _diff_schema(dst, src): @@ -193,11 +454,13 @@ def _hash_schema(s): class _Commit: result: Dict[str, Any] checksum_result: str - path: str + yaml_path: str additions: Dict[str, Any] subtractions: Dict[str, Any] base: Dict[str, Any] checksum_base: Optional[str] + cpp_header: str + cpp_header_path: str def update_schema(): @@ -217,16 +480,22 @@ def update_schema(): checksum_base = None dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None} - src = _staged_schema() + src, cpp_header = _staged_schema() additions, subtractions = _diff_schema(dst, src) + yaml_path = __package__.replace(".", "/") + "/schema.yaml" + torch_prefix = "torch/" + assert yaml_path.startswith(torch_prefix) # sanity check + return _Commit( result=src, checksum_result=_hash_schema(src), - path=__package__.replace(".", "/") + "/schema.yaml", + yaml_path=yaml_path, additions=additions, subtractions=subtractions, base=dst, checksum_base=checksum_base, + cpp_header=cpp_header, + cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h", ) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 33641db5dd6b..f3c1ec6cd364 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -697,7 +697,7 @@ class GraphModuleSerializer(metaclass=Final): return inputs def is_sym_int_arg(self, arg) -> bool: - return isinstance(arg, int) or ( + return type(arg) is int or ( isinstance(arg, torch.fx.Node) and arg.name in self.graph_state.sym_int_values ) @@ -770,13 +770,13 @@ class GraphModuleSerializer(metaclass=Final): # For regular FX graph, SymInt arg should be a fx.Node with # self.is_sym_int_arg(arg) being true return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg))) - elif isinstance(arg, bool): + elif type(arg) is bool: return Argument.create(as_bool=arg) - elif isinstance(arg, str): + elif type(arg) is str: return Argument.create(as_string=arg) - elif isinstance(arg, int): + elif type(arg) is int: return Argument.create(as_int=arg) - elif isinstance(arg, float): + elif type(arg) is float: return Argument.create(as_float=arg) elif arg is None: return Argument.create(as_none=()) @@ -814,14 +814,13 @@ class GraphModuleSerializer(metaclass=Final): ) return Argument.create(as_tensors=[]) - # Must check bool first, as bool is also treated as int - if all(isinstance(a, bool) for a in arg): + if all(type(a) is bool for a in arg): return Argument.create(as_bools=list(arg)) - elif all(isinstance(a, int) for a in arg): + elif all(type(a) is int for a in arg): return Argument.create(as_ints=list(arg)) - elif all(isinstance(a, float) for a in arg): + elif all(type(a) is float for a in arg): return Argument.create(as_floats=list(arg)) - elif all(isinstance(a, str) for a in arg): + elif all(type(a) is str for a in arg): return Argument.create(as_strings=list(arg)) elif all(isinstance(a, torch.SymInt) for a in arg): # This is a special branch for handling SymInt args in inductor's @@ -837,7 +836,7 @@ class GraphModuleSerializer(metaclass=Final): for a in arg: if isinstance(a, torch.fx.Node): values.append(SymIntArgument.create(as_name=a.name)) - elif isinstance(a, int): + elif type(a) is int: values.append(SymIntArgument.create(as_int=a)) return Argument.create(as_sym_ints=values) elif all(self.is_sym_bool_arg(a) for a in arg): @@ -952,13 +951,13 @@ class GraphModuleSerializer(metaclass=Final): def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: if spec.kind == ep.InputKind.USER_INPUT: if isinstance(spec.arg, ep.ConstantArgument): - if isinstance(spec.arg.value, int): + if type(spec.arg.value) is int: constant_spec = ConstantValue.create(as_int=spec.arg.value) - elif isinstance(spec.arg.value, bool): + elif type(spec.arg.value) is bool: constant_spec = ConstantValue.create(as_bool=spec.arg.value) - elif isinstance(spec.arg.value, str): + elif type(spec.arg.value) is str: constant_spec = ConstantValue.create(as_string=spec.arg.value) - elif isinstance(spec.arg.value, float): + elif type(spec.arg.value) is float: constant_spec = ConstantValue.create(as_float=spec.arg.value) elif spec.arg.value is None: constant_spec = ConstantValue.create(as_none=()) @@ -1548,7 +1547,7 @@ class GraphModuleDeserializer(metaclass=Final): return self.shape_env.create_symintnode(sym, hint=hint) elif s.type == "as_int": - assert isinstance(val, int) + assert type(val) is int return val else: raise SerializeError( diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index a9d68d7b4bf2..fbd5c3ea64df 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -68,6 +68,7 @@ #include #include #include +#include #include #include #include @@ -1773,6 +1774,7 @@ PyObject* initModule() { torch::profiler::initPythonBindings(module); torch::python::init_bindings(module); torch::lazy::initLazyBindings(module); + torch::_export::initExportBindings(module); torch::inductor::initAOTIRunnerBindings(module); torch::inductor::initAOTIPackageBindings(module); #ifdef USE_ITT diff --git a/torch/csrc/export/pybind.cpp b/torch/csrc/export/pybind.cpp new file mode 100644 index 000000000000..458b08c3f361 --- /dev/null +++ b/torch/csrc/export/pybind.cpp @@ -0,0 +1,20 @@ +#include +#include + +namespace torch::_export { + +void initExportBindings(PyObject* module) { + auto rootModule = py::handle(module).cast(); + auto m = rootModule.def_submodule("_export"); + + py::class_(m, "CppExportedProgram"); + + m.def("deserialize_exported_program", [](const std::string& serialized) { + return nlohmann::json::parse(serialized).get(); + }); + + m.def("serialize_exported_program", [](const ExportedProgram& ep) { + return nlohmann::json(ep).dump(); + }); +} +} // namespace torch::_export diff --git a/torch/csrc/export/pybind.h b/torch/csrc/export/pybind.h new file mode 100644 index 000000000000..75b954cdccee --- /dev/null +++ b/torch/csrc/export/pybind.h @@ -0,0 +1,7 @@ +#include + +namespace torch::_export { + +void initExportBindings(PyObject* module); + +} // namespace torch::_export diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h new file mode 100644 index 000000000000..934a73022ba1 --- /dev/null +++ b/torch/csrc/utils/generated_serialization_types.h @@ -0,0 +1,2188 @@ +// @generated by update_schema.py +// checksum<<8e27d48014d4ec1c773aef056c0c20b61bead54be8338e95d3347d3422472b9a>> +// clang-format off + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN +#define NLOHMANN_JSON_NAMESPACE_BEGIN namespace nlohmann { +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_END +#define NLOHMANN_JSON_NAMESPACE_END } +#endif + +// https://github.com/nlohmann/json/pull/2117 +NLOHMANN_JSON_NAMESPACE_BEGIN +template +struct adl_serializer> { + static void to_json(json& j, const std::optional& opt) { + if (opt == std::nullopt) { + j = nullptr; + } else { + j = *opt; // this will call adl_serializer::to_json which will + // find the free function to_json in T's namespace! + } + } + + static void from_json(const json& j, std::optional& opt) { + if (j.is_null()) { + opt = std::nullopt; + } else { + opt = j.template get(); // same as above, but with + // adl_serializer::from_json + } + } +}; +NLOHMANN_JSON_NAMESPACE_END + +namespace torch { +namespace _export { + +template +class ForwardRef { + static_assert(!std::is_reference_v, "ForwardRef cannot be a reference type"); + + public: + ForwardRef(): ptr_(std::make_unique()) {} + ForwardRef(ForwardRef&&) = default; + ForwardRef(const ForwardRef& other): ptr_(std::make_unique(*other.ptr_)) {} + ForwardRef& operator=(ForwardRef&&) = default; + ForwardRef& operator=(const ForwardRef& other) { + ptr_ = std::make_unique(*other.ptr_); + } + const T& operator*() const { + return *ptr_; + } + + const T* operator->() const { + return ptr_.get(); + } + + void emplace(T&& t) { + ptr_ = std::make_unique(std::move(t)); + } + + private: + std::unique_ptr ptr_; +}; + +template +void to_json(nlohmann::json& j, const ForwardRef& p) { + j = *p; +} + +template +void from_json(const nlohmann::json& j, ForwardRef& p) { + p.emplace(j.template get()); +} + +class Argument; +class BufferMutationSpec; +class ConstantInputSpec; +class ConstantValue; +class CustomObjArgument; +class Device; +class ExportedProgram; +class GradientToParameterSpec; +class GradientToUserInputSpec; +class Graph; +class GraphArgument; +class GraphModule; +class GraphSignature; +class InputSpec; +class InputToBufferSpec; +class InputToCustomObjSpec; +class InputToParameterSpec; +class InputToTensorConstantSpec; +class InputTokenSpec; +class LossOutputSpec; +class ModuleCallEntry; +class ModuleCallSignature; +class NamedArgument; +class Node; +class OptionalTensorArgument; +class OutputSpec; +class OutputTokenSpec; +class RangeConstraint; +class SchemaVersion; +class SymBool; +class SymBoolArgument; +class SymExpr; +class SymExprHint; +class SymInt; +class SymIntArgument; +class TensorArgument; +class TensorMeta; +class TokenArgument; +class UserInputMutationSpec; +class UserInputSpec; +class UserOutputSpec; + +enum class Layout { + Unknown = 0, + SparseCoo = 1, + SparseCsr = 2, + SparseCsc = 3, + SparseBsr = 4, + SparseBsc = 5, + _mkldnn = 6, + Strided = 7, +}; + +enum class MemoryFormat { + Unknown = 0, + ContiguousFormat = 1, + ChannelsLast = 2, + ChannelsLast3d = 3, + PreserveFormat = 4, +}; + +enum class ScalarType { + UNKNOWN = 0, + BYTE = 1, + CHAR = 2, + SHORT = 3, + INT = 4, + LONG = 5, + HALF = 6, + FLOAT = 7, + DOUBLE = 8, + COMPLEXHALF = 9, + COMPLEXFLOAT = 10, + COMPLEXDOUBLE = 11, + BOOL = 12, + BFLOAT16 = 13, + UINT16 = 28, +}; + + +class Device { + private: + std::string type; + std::optional index = std::nullopt; + + public: + + const std::string& get_type() const { + return type; + } + + const std::optional& get_index() const { + return index; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Device& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Device& nlohmann_json_t); +}; + +class SymExprHint { + struct Void {}; + + public: + enum class Tag { + AS_INT, AS_FLOAT, AS_BOOL + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const int64_t& get_as_int() const { + return std::get<1>(variant_); + } + + const double& get_as_float() const { + return std::get<2>(variant_); + } + + const bool& get_as_bool() const { + return std::get<3>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymExprHint& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymExprHint& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + } +}; + +class SymExpr { + private: + std::string expr_str; + std::optional hint = std::nullopt; + + public: + + const std::string& get_expr_str() const { + return expr_str; + } + + const std::optional& get_hint() const { + return hint; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymExpr& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, SymExpr& nlohmann_json_t); +}; + +class SymInt { + struct Void {}; + + public: + enum class Tag { + AS_EXPR, AS_INT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const SymExpr& get_as_expr() const { + return std::get<1>(variant_); + } + + const int64_t& get_as_int() const { + return std::get<2>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymInt& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_EXPR) { + nlohmann_json_j["as_expr"] = nlohmann_json_t.get_as_expr(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymInt& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_expr")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_expr").template get()); + nlohmann_json_t.tag_ = Tag::AS_EXPR; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + } +}; + +class SymBool { + struct Void {}; + + public: + enum class Tag { + AS_EXPR, AS_BOOL + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const SymExpr& get_as_expr() const { + return std::get<1>(variant_); + } + + const bool& get_as_bool() const { + return std::get<2>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymBool& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_EXPR) { + nlohmann_json_j["as_expr"] = nlohmann_json_t.get_as_expr(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymBool& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_expr")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_expr").template get()); + nlohmann_json_t.tag_ = Tag::AS_EXPR; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + } +}; + +class TensorMeta { + private: + int64_t dtype; + std::vector sizes; + bool requires_grad; + Device device; + std::vector strides; + SymInt storage_offset; + int64_t layout; + + public: + + ScalarType get_dtype() const { + return static_cast(dtype); + } + + const std::vector& get_sizes() const { + return sizes; + } + + const bool& get_requires_grad() const { + return requires_grad; + } + + const Device& get_device() const { + return device; + } + + const std::vector& get_strides() const { + return strides; + } + + const SymInt& get_storage_offset() const { + return storage_offset; + } + + Layout get_layout() const { + return static_cast(layout); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const TensorMeta& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, TensorMeta& nlohmann_json_t); +}; + +class SymIntArgument { + struct Void {}; + + public: + enum class Tag { + AS_NAME, AS_INT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const std::string& get_as_name() const { + return std::get<1>(variant_); + } + + const int64_t& get_as_int() const { + return std::get<2>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymIntArgument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NAME) { + nlohmann_json_j["as_name"] = nlohmann_json_t.get_as_name(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymIntArgument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_name")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_name").template get()); + nlohmann_json_t.tag_ = Tag::AS_NAME; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + } +}; + +class SymBoolArgument { + struct Void {}; + + public: + enum class Tag { + AS_NAME, AS_BOOL + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const std::string& get_as_name() const { + return std::get<1>(variant_); + } + + const bool& get_as_bool() const { + return std::get<2>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymBoolArgument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NAME) { + nlohmann_json_j["as_name"] = nlohmann_json_t.get_as_name(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymBoolArgument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_name")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_name").template get()); + nlohmann_json_t.tag_ = Tag::AS_NAME; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + } +}; + +class TensorArgument { + private: + std::string name; + + public: + + const std::string& get_name() const { + return name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const TensorArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, TensorArgument& nlohmann_json_t); +}; + +class TokenArgument { + private: + std::string name; + + public: + + const std::string& get_name() const { + return name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const TokenArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, TokenArgument& nlohmann_json_t); +}; + +class OptionalTensorArgument { + struct Void {}; + + public: + enum class Tag { + AS_TENSOR, AS_NONE + }; + + private: + std::variant> variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const TensorArgument& get_as_tensor() const { + return std::get<1>(variant_); + } + + const std::tuple<>& get_as_none() const { + return std::get<2>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const OptionalTensorArgument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_TENSOR) { + nlohmann_json_j["as_tensor"] = nlohmann_json_t.get_as_tensor(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_NONE) { + nlohmann_json_j["as_none"] = nlohmann_json_t.get_as_none(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, OptionalTensorArgument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_tensor")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_tensor").template get()); + nlohmann_json_t.tag_ = Tag::AS_TENSOR; + return; + } + if (nlohmann_json_j.contains("as_none")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_none").template get>()); + nlohmann_json_t.tag_ = Tag::AS_NONE; + return; + } + } +}; + +class GraphArgument { + private: + std::string name; + ForwardRef graph; + + public: + + const std::string& get_name() const { + return name; + } + + const ForwardRef& get_graph() const { + return graph; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GraphArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GraphArgument& nlohmann_json_t); +}; + +class CustomObjArgument { + private: + std::string name; + std::string class_fqn; + + public: + + const std::string& get_name() const { + return name; + } + + const std::string& get_class_fqn() const { + return class_fqn; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const CustomObjArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, CustomObjArgument& nlohmann_json_t); +}; + +class Argument { + struct Void {}; + + public: + enum class Tag { + AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR + }; + + private: + std::variant, TensorArgument, std::vector, int64_t, std::vector, double, std::vector, std::string, std::vector, SymIntArgument, std::vector, ScalarType, MemoryFormat, Layout, Device, bool, std::vector, SymBoolArgument, std::vector, GraphArgument, std::vector, CustomObjArgument, std::string> variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const std::tuple<>& get_as_none() const { + return std::get<1>(variant_); + } + + const TensorArgument& get_as_tensor() const { + return std::get<2>(variant_); + } + + const std::vector& get_as_tensors() const { + return std::get<3>(variant_); + } + + const int64_t& get_as_int() const { + return std::get<4>(variant_); + } + + const std::vector& get_as_ints() const { + return std::get<5>(variant_); + } + + const double& get_as_float() const { + return std::get<6>(variant_); + } + + const std::vector& get_as_floats() const { + return std::get<7>(variant_); + } + + const std::string& get_as_string() const { + return std::get<8>(variant_); + } + + const std::vector& get_as_strings() const { + return std::get<9>(variant_); + } + + const SymIntArgument& get_as_sym_int() const { + return std::get<10>(variant_); + } + + const std::vector& get_as_sym_ints() const { + return std::get<11>(variant_); + } + + const ScalarType& get_as_scalar_type() const { + return std::get<12>(variant_); + } + + const MemoryFormat& get_as_memory_format() const { + return std::get<13>(variant_); + } + + const Layout& get_as_layout() const { + return std::get<14>(variant_); + } + + const Device& get_as_device() const { + return std::get<15>(variant_); + } + + const bool& get_as_bool() const { + return std::get<16>(variant_); + } + + const std::vector& get_as_bools() const { + return std::get<17>(variant_); + } + + const SymBoolArgument& get_as_sym_bool() const { + return std::get<18>(variant_); + } + + const std::vector& get_as_sym_bools() const { + return std::get<19>(variant_); + } + + const GraphArgument& get_as_graph() const { + return std::get<20>(variant_); + } + + const std::vector& get_as_optional_tensors() const { + return std::get<21>(variant_); + } + + const CustomObjArgument& get_as_custom_obj() const { + return std::get<22>(variant_); + } + + const std::string& get_as_operator() const { + return std::get<23>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Argument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NONE) { + nlohmann_json_j["as_none"] = nlohmann_json_t.get_as_none(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_TENSOR) { + nlohmann_json_j["as_tensor"] = nlohmann_json_t.get_as_tensor(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_TENSORS) { + nlohmann_json_j["as_tensors"] = nlohmann_json_t.get_as_tensors(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INTS) { + nlohmann_json_j["as_ints"] = nlohmann_json_t.get_as_ints(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOATS) { + nlohmann_json_j["as_floats"] = nlohmann_json_t.get_as_floats(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_STRING) { + nlohmann_json_j["as_string"] = nlohmann_json_t.get_as_string(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_STRINGS) { + nlohmann_json_j["as_strings"] = nlohmann_json_t.get_as_strings(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_INT) { + nlohmann_json_j["as_sym_int"] = nlohmann_json_t.get_as_sym_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_INTS) { + nlohmann_json_j["as_sym_ints"] = nlohmann_json_t.get_as_sym_ints(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SCALAR_TYPE) { + nlohmann_json_j["as_scalar_type"] = nlohmann_json_t.get_as_scalar_type(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_MEMORY_FORMAT) { + nlohmann_json_j["as_memory_format"] = nlohmann_json_t.get_as_memory_format(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_LAYOUT) { + nlohmann_json_j["as_layout"] = nlohmann_json_t.get_as_layout(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_DEVICE) { + nlohmann_json_j["as_device"] = nlohmann_json_t.get_as_device(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOLS) { + nlohmann_json_j["as_bools"] = nlohmann_json_t.get_as_bools(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_BOOL) { + nlohmann_json_j["as_sym_bool"] = nlohmann_json_t.get_as_sym_bool(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_BOOLS) { + nlohmann_json_j["as_sym_bools"] = nlohmann_json_t.get_as_sym_bools(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_GRAPH) { + nlohmann_json_j["as_graph"] = nlohmann_json_t.get_as_graph(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_OPTIONAL_TENSORS) { + nlohmann_json_j["as_optional_tensors"] = nlohmann_json_t.get_as_optional_tensors(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_CUSTOM_OBJ) { + nlohmann_json_j["as_custom_obj"] = nlohmann_json_t.get_as_custom_obj(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_OPERATOR) { + nlohmann_json_j["as_operator"] = nlohmann_json_t.get_as_operator(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, Argument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_none")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_none").template get>()); + nlohmann_json_t.tag_ = Tag::AS_NONE; + return; + } + if (nlohmann_json_j.contains("as_tensor")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_tensor").template get()); + nlohmann_json_t.tag_ = Tag::AS_TENSOR; + return; + } + if (nlohmann_json_j.contains("as_tensors")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_tensors").template get>()); + nlohmann_json_t.tag_ = Tag::AS_TENSORS; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + if (nlohmann_json_j.contains("as_ints")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("as_ints").template get>()); + nlohmann_json_t.tag_ = Tag::AS_INTS; + return; + } + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<6>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; + return; + } + if (nlohmann_json_j.contains("as_floats")) { + nlohmann_json_t.variant_.emplace<7>(nlohmann_json_j.at("as_floats").template get>()); + nlohmann_json_t.tag_ = Tag::AS_FLOATS; + return; + } + if (nlohmann_json_j.contains("as_string")) { + nlohmann_json_t.variant_.emplace<8>(nlohmann_json_j.at("as_string").template get()); + nlohmann_json_t.tag_ = Tag::AS_STRING; + return; + } + if (nlohmann_json_j.contains("as_strings")) { + nlohmann_json_t.variant_.emplace<9>(nlohmann_json_j.at("as_strings").template get>()); + nlohmann_json_t.tag_ = Tag::AS_STRINGS; + return; + } + if (nlohmann_json_j.contains("as_sym_int")) { + nlohmann_json_t.variant_.emplace<10>(nlohmann_json_j.at("as_sym_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_SYM_INT; + return; + } + if (nlohmann_json_j.contains("as_sym_ints")) { + nlohmann_json_t.variant_.emplace<11>(nlohmann_json_j.at("as_sym_ints").template get>()); + nlohmann_json_t.tag_ = Tag::AS_SYM_INTS; + return; + } + if (nlohmann_json_j.contains("as_scalar_type")) { + nlohmann_json_t.variant_.emplace<12>(nlohmann_json_j.at("as_scalar_type").template get()); + nlohmann_json_t.tag_ = Tag::AS_SCALAR_TYPE; + return; + } + if (nlohmann_json_j.contains("as_memory_format")) { + nlohmann_json_t.variant_.emplace<13>(nlohmann_json_j.at("as_memory_format").template get()); + nlohmann_json_t.tag_ = Tag::AS_MEMORY_FORMAT; + return; + } + if (nlohmann_json_j.contains("as_layout")) { + nlohmann_json_t.variant_.emplace<14>(nlohmann_json_j.at("as_layout").template get()); + nlohmann_json_t.tag_ = Tag::AS_LAYOUT; + return; + } + if (nlohmann_json_j.contains("as_device")) { + nlohmann_json_t.variant_.emplace<15>(nlohmann_json_j.at("as_device").template get()); + nlohmann_json_t.tag_ = Tag::AS_DEVICE; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<16>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + if (nlohmann_json_j.contains("as_bools")) { + nlohmann_json_t.variant_.emplace<17>(nlohmann_json_j.at("as_bools").template get>()); + nlohmann_json_t.tag_ = Tag::AS_BOOLS; + return; + } + if (nlohmann_json_j.contains("as_sym_bool")) { + nlohmann_json_t.variant_.emplace<18>(nlohmann_json_j.at("as_sym_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_SYM_BOOL; + return; + } + if (nlohmann_json_j.contains("as_sym_bools")) { + nlohmann_json_t.variant_.emplace<19>(nlohmann_json_j.at("as_sym_bools").template get>()); + nlohmann_json_t.tag_ = Tag::AS_SYM_BOOLS; + return; + } + if (nlohmann_json_j.contains("as_graph")) { + nlohmann_json_t.variant_.emplace<20>(nlohmann_json_j.at("as_graph").template get()); + nlohmann_json_t.tag_ = Tag::AS_GRAPH; + return; + } + if (nlohmann_json_j.contains("as_optional_tensors")) { + nlohmann_json_t.variant_.emplace<21>(nlohmann_json_j.at("as_optional_tensors").template get>()); + nlohmann_json_t.tag_ = Tag::AS_OPTIONAL_TENSORS; + return; + } + if (nlohmann_json_j.contains("as_custom_obj")) { + nlohmann_json_t.variant_.emplace<22>(nlohmann_json_j.at("as_custom_obj").template get()); + nlohmann_json_t.tag_ = Tag::AS_CUSTOM_OBJ; + return; + } + if (nlohmann_json_j.contains("as_operator")) { + nlohmann_json_t.variant_.emplace<23>(nlohmann_json_j.at("as_operator").template get()); + nlohmann_json_t.tag_ = Tag::AS_OPERATOR; + return; + } + } +}; + +class NamedArgument { + private: + std::string name; + Argument arg; + + public: + + const std::string& get_name() const { + return name; + } + + const Argument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const NamedArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, NamedArgument& nlohmann_json_t); +}; + +class Node { + private: + std::string target; + std::vector inputs; + std::vector outputs; + std::unordered_map metadata; + + public: + + const std::string& get_target() const { + return target; + } + + const std::vector& get_inputs() const { + return inputs; + } + + const std::vector& get_outputs() const { + return outputs; + } + + const std::unordered_map& get_metadata() const { + return metadata; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Node& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Node& nlohmann_json_t); +}; + +class Graph { + private: + std::vector inputs; + std::vector outputs; + std::vector nodes; + std::unordered_map tensor_values; + std::unordered_map sym_int_values; + std::unordered_map sym_bool_values; + bool is_single_tensor_return = false; + std::unordered_map custom_obj_values = {}; + + public: + + const std::vector& get_inputs() const { + return inputs; + } + + const std::vector& get_outputs() const { + return outputs; + } + + const std::vector& get_nodes() const { + return nodes; + } + + const std::unordered_map& get_tensor_values() const { + return tensor_values; + } + + const std::unordered_map& get_sym_int_values() const { + return sym_int_values; + } + + const std::unordered_map& get_sym_bool_values() const { + return sym_bool_values; + } + + const bool& get_is_single_tensor_return() const { + return is_single_tensor_return; + } + + const std::unordered_map& get_custom_obj_values() const { + return custom_obj_values; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Graph& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_json_t); +}; + +class UserInputSpec { + private: + Argument arg; + + public: + + const Argument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const UserInputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, UserInputSpec& nlohmann_json_t); +}; + +class ConstantValue { + struct Void {}; + + public: + enum class Tag { + AS_NONE, AS_INT, AS_FLOAT, AS_STRING, AS_BOOL + }; + + private: + std::variant, int64_t, double, std::string, bool> variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const std::tuple<>& get_as_none() const { + return std::get<1>(variant_); + } + + const int64_t& get_as_int() const { + return std::get<2>(variant_); + } + + const double& get_as_float() const { + return std::get<3>(variant_); + } + + const std::string& get_as_string() const { + return std::get<4>(variant_); + } + + const bool& get_as_bool() const { + return std::get<5>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ConstantValue& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NONE) { + nlohmann_json_j["as_none"] = nlohmann_json_t.get_as_none(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_STRING) { + nlohmann_json_j["as_string"] = nlohmann_json_t.get_as_string(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, ConstantValue& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_none")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_none").template get>()); + nlohmann_json_t.tag_ = Tag::AS_NONE; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; + return; + } + if (nlohmann_json_j.contains("as_string")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("as_string").template get()); + nlohmann_json_t.tag_ = Tag::AS_STRING; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + } +}; + +class ConstantInputSpec { + private: + std::string name; + ConstantValue value; + + public: + + const std::string& get_name() const { + return name; + } + + const ConstantValue& get_value() const { + return value; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ConstantInputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ConstantInputSpec& nlohmann_json_t); +}; + +class InputToParameterSpec { + private: + TensorArgument arg; + std::string parameter_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_parameter_name() const { + return parameter_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToParameterSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToParameterSpec& nlohmann_json_t); +}; + +class InputToBufferSpec { + private: + TensorArgument arg; + std::string buffer_name; + bool persistent; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_buffer_name() const { + return buffer_name; + } + + const bool& get_persistent() const { + return persistent; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToBufferSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToBufferSpec& nlohmann_json_t); +}; + +class InputToTensorConstantSpec { + private: + TensorArgument arg; + std::string tensor_constant_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_tensor_constant_name() const { + return tensor_constant_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToTensorConstantSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToTensorConstantSpec& nlohmann_json_t); +}; + +class InputToCustomObjSpec { + private: + CustomObjArgument arg; + std::string custom_obj_name; + + public: + + const CustomObjArgument& get_arg() const { + return arg; + } + + const std::string& get_custom_obj_name() const { + return custom_obj_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToCustomObjSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToCustomObjSpec& nlohmann_json_t); +}; + +class InputTokenSpec { + private: + TokenArgument arg; + + public: + + const TokenArgument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputTokenSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputTokenSpec& nlohmann_json_t); +}; + +class InputSpec { + struct Void {}; + + public: + enum class Tag { + USER_INPUT, PARAMETER, BUFFER, TENSOR_CONSTANT, CUSTOM_OBJ, TOKEN, CONSTANT_INPUT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const UserInputSpec& get_user_input() const { + return std::get<1>(variant_); + } + + const InputToParameterSpec& get_parameter() const { + return std::get<2>(variant_); + } + + const InputToBufferSpec& get_buffer() const { + return std::get<3>(variant_); + } + + const InputToTensorConstantSpec& get_tensor_constant() const { + return std::get<4>(variant_); + } + + const InputToCustomObjSpec& get_custom_obj() const { + return std::get<5>(variant_); + } + + const InputTokenSpec& get_token() const { + return std::get<6>(variant_); + } + + const ConstantInputSpec& get_constant_input() const { + return std::get<7>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputSpec& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::USER_INPUT) { + nlohmann_json_j["user_input"] = nlohmann_json_t.get_user_input(); + return; + } + if (nlohmann_json_t.tag_ == Tag::PARAMETER) { + nlohmann_json_j["parameter"] = nlohmann_json_t.get_parameter(); + return; + } + if (nlohmann_json_t.tag_ == Tag::BUFFER) { + nlohmann_json_j["buffer"] = nlohmann_json_t.get_buffer(); + return; + } + if (nlohmann_json_t.tag_ == Tag::TENSOR_CONSTANT) { + nlohmann_json_j["tensor_constant"] = nlohmann_json_t.get_tensor_constant(); + return; + } + if (nlohmann_json_t.tag_ == Tag::CUSTOM_OBJ) { + nlohmann_json_j["custom_obj"] = nlohmann_json_t.get_custom_obj(); + return; + } + if (nlohmann_json_t.tag_ == Tag::TOKEN) { + nlohmann_json_j["token"] = nlohmann_json_t.get_token(); + return; + } + if (nlohmann_json_t.tag_ == Tag::CONSTANT_INPUT) { + nlohmann_json_j["constant_input"] = nlohmann_json_t.get_constant_input(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, InputSpec& nlohmann_json_t) { + + if (nlohmann_json_j.contains("user_input")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("user_input").template get()); + nlohmann_json_t.tag_ = Tag::USER_INPUT; + return; + } + if (nlohmann_json_j.contains("parameter")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("parameter").template get()); + nlohmann_json_t.tag_ = Tag::PARAMETER; + return; + } + if (nlohmann_json_j.contains("buffer")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("buffer").template get()); + nlohmann_json_t.tag_ = Tag::BUFFER; + return; + } + if (nlohmann_json_j.contains("tensor_constant")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("tensor_constant").template get()); + nlohmann_json_t.tag_ = Tag::TENSOR_CONSTANT; + return; + } + if (nlohmann_json_j.contains("custom_obj")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("custom_obj").template get()); + nlohmann_json_t.tag_ = Tag::CUSTOM_OBJ; + return; + } + if (nlohmann_json_j.contains("token")) { + nlohmann_json_t.variant_.emplace<6>(nlohmann_json_j.at("token").template get()); + nlohmann_json_t.tag_ = Tag::TOKEN; + return; + } + if (nlohmann_json_j.contains("constant_input")) { + nlohmann_json_t.variant_.emplace<7>(nlohmann_json_j.at("constant_input").template get()); + nlohmann_json_t.tag_ = Tag::CONSTANT_INPUT; + return; + } + } +}; + +class UserOutputSpec { + private: + Argument arg; + + public: + + const Argument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const UserOutputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlohmann_json_t); +}; + +class LossOutputSpec { + private: + TensorArgument arg; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const LossOutputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, LossOutputSpec& nlohmann_json_t); +}; + +class BufferMutationSpec { + private: + TensorArgument arg; + std::string buffer_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_buffer_name() const { + return buffer_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const BufferMutationSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, BufferMutationSpec& nlohmann_json_t); +}; + +class GradientToParameterSpec { + private: + TensorArgument arg; + std::string parameter_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_parameter_name() const { + return parameter_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GradientToParameterSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GradientToParameterSpec& nlohmann_json_t); +}; + +class GradientToUserInputSpec { + private: + TensorArgument arg; + std::string user_input_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_user_input_name() const { + return user_input_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GradientToUserInputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GradientToUserInputSpec& nlohmann_json_t); +}; + +class UserInputMutationSpec { + private: + TensorArgument arg; + std::string user_input_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_user_input_name() const { + return user_input_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const UserInputMutationSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, UserInputMutationSpec& nlohmann_json_t); +}; + +class OutputTokenSpec { + private: + TokenArgument arg; + + public: + + const TokenArgument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const OutputTokenSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, OutputTokenSpec& nlohmann_json_t); +}; + +class OutputSpec { + struct Void {}; + + public: + enum class Tag { + USER_OUTPUT, LOSS_OUTPUT, BUFFER_MUTATION, GRADIENT_TO_PARAMETER, GRADIENT_TO_USER_INPUT, USER_INPUT_MUTATION, TOKEN + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const UserOutputSpec& get_user_output() const { + return std::get<1>(variant_); + } + + const LossOutputSpec& get_loss_output() const { + return std::get<2>(variant_); + } + + const BufferMutationSpec& get_buffer_mutation() const { + return std::get<3>(variant_); + } + + const GradientToParameterSpec& get_gradient_to_parameter() const { + return std::get<4>(variant_); + } + + const GradientToUserInputSpec& get_gradient_to_user_input() const { + return std::get<5>(variant_); + } + + const UserInputMutationSpec& get_user_input_mutation() const { + return std::get<6>(variant_); + } + + const OutputTokenSpec& get_token() const { + return std::get<7>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const OutputSpec& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::USER_OUTPUT) { + nlohmann_json_j["user_output"] = nlohmann_json_t.get_user_output(); + return; + } + if (nlohmann_json_t.tag_ == Tag::LOSS_OUTPUT) { + nlohmann_json_j["loss_output"] = nlohmann_json_t.get_loss_output(); + return; + } + if (nlohmann_json_t.tag_ == Tag::BUFFER_MUTATION) { + nlohmann_json_j["buffer_mutation"] = nlohmann_json_t.get_buffer_mutation(); + return; + } + if (nlohmann_json_t.tag_ == Tag::GRADIENT_TO_PARAMETER) { + nlohmann_json_j["gradient_to_parameter"] = nlohmann_json_t.get_gradient_to_parameter(); + return; + } + if (nlohmann_json_t.tag_ == Tag::GRADIENT_TO_USER_INPUT) { + nlohmann_json_j["gradient_to_user_input"] = nlohmann_json_t.get_gradient_to_user_input(); + return; + } + if (nlohmann_json_t.tag_ == Tag::USER_INPUT_MUTATION) { + nlohmann_json_j["user_input_mutation"] = nlohmann_json_t.get_user_input_mutation(); + return; + } + if (nlohmann_json_t.tag_ == Tag::TOKEN) { + nlohmann_json_j["token"] = nlohmann_json_t.get_token(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, OutputSpec& nlohmann_json_t) { + + if (nlohmann_json_j.contains("user_output")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("user_output").template get()); + nlohmann_json_t.tag_ = Tag::USER_OUTPUT; + return; + } + if (nlohmann_json_j.contains("loss_output")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("loss_output").template get()); + nlohmann_json_t.tag_ = Tag::LOSS_OUTPUT; + return; + } + if (nlohmann_json_j.contains("buffer_mutation")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("buffer_mutation").template get()); + nlohmann_json_t.tag_ = Tag::BUFFER_MUTATION; + return; + } + if (nlohmann_json_j.contains("gradient_to_parameter")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("gradient_to_parameter").template get()); + nlohmann_json_t.tag_ = Tag::GRADIENT_TO_PARAMETER; + return; + } + if (nlohmann_json_j.contains("gradient_to_user_input")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("gradient_to_user_input").template get()); + nlohmann_json_t.tag_ = Tag::GRADIENT_TO_USER_INPUT; + return; + } + if (nlohmann_json_j.contains("user_input_mutation")) { + nlohmann_json_t.variant_.emplace<6>(nlohmann_json_j.at("user_input_mutation").template get()); + nlohmann_json_t.tag_ = Tag::USER_INPUT_MUTATION; + return; + } + if (nlohmann_json_j.contains("token")) { + nlohmann_json_t.variant_.emplace<7>(nlohmann_json_j.at("token").template get()); + nlohmann_json_t.tag_ = Tag::TOKEN; + return; + } + } +}; + +class GraphSignature { + private: + std::vector input_specs; + std::vector output_specs; + + public: + + const std::vector& get_input_specs() const { + return input_specs; + } + + const std::vector& get_output_specs() const { + return output_specs; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GraphSignature& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GraphSignature& nlohmann_json_t); +}; + +class RangeConstraint { + private: + std::optional min_val; + std::optional max_val; + + public: + + const std::optional& get_min_val() const { + return min_val; + } + + const std::optional& get_max_val() const { + return max_val; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const RangeConstraint& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, RangeConstraint& nlohmann_json_t); +}; + +class ModuleCallSignature { + private: + std::vector inputs; + std::vector outputs; + std::string in_spec; + std::string out_spec; + std::optional> forward_arg_names = std::nullopt; + + public: + + const std::vector& get_inputs() const { + return inputs; + } + + const std::vector& get_outputs() const { + return outputs; + } + + const std::string& get_in_spec() const { + return in_spec; + } + + const std::string& get_out_spec() const { + return out_spec; + } + + const std::optional>& get_forward_arg_names() const { + return forward_arg_names; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallSignature& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallSignature& nlohmann_json_t); +}; + +class ModuleCallEntry { + private: + std::string fqn; + std::optional signature = std::nullopt; + + public: + + const std::string& get_fqn() const { + return fqn; + } + + const std::optional& get_signature() const { + return signature; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallEntry& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallEntry& nlohmann_json_t); +}; + +class GraphModule { + private: + Graph graph; + GraphSignature signature; + std::vector module_call_graph; + std::unordered_map metadata = {}; + + public: + + const Graph& get_graph() const { + return graph; + } + + const GraphSignature& get_signature() const { + return signature; + } + + const std::vector& get_module_call_graph() const { + return module_call_graph; + } + + const std::unordered_map& get_metadata() const { + return metadata; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GraphModule& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GraphModule& nlohmann_json_t); +}; + +class SchemaVersion { + private: + int64_t major; + int64_t minor; + + public: + + const int64_t& get_major() const { + return major; + } + + const int64_t& get_minor() const { + return minor; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SchemaVersion& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, SchemaVersion& nlohmann_json_t); +}; + +class ExportedProgram { + private: + GraphModule graph_module; + std::unordered_map opset_version; + std::unordered_map range_constraints; + SchemaVersion schema_version; + std::vector verifiers = {}; + std::string torch_version = "<=2.4"; + + public: + + const GraphModule& get_graph_module() const { + return graph_module; + } + + const std::unordered_map& get_opset_version() const { + return opset_version; + } + + const std::unordered_map& get_range_constraints() const { + return range_constraints; + } + + const SchemaVersion& get_schema_version() const { + return schema_version; + } + + const std::vector& get_verifiers() const { + return verifiers; + } + + const std::string& get_torch_version() const { + return torch_version; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t); +}; + +inline void to_json(nlohmann::json& nlohmann_json_j, const BufferMutationSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["buffer_name"] = nlohmann_json_t.buffer_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, BufferMutationSpec& nlohmann_json_t) { + BufferMutationSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.buffer_name = nlohmann_json_j.value("buffer_name", nlohmann_json_default_obj.buffer_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ConstantInputSpec& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["value"] = nlohmann_json_t.value; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ConstantInputSpec& nlohmann_json_t) { + ConstantInputSpec nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.value = nlohmann_json_j.value("value", nlohmann_json_default_obj.value); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const CustomObjArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["class_fqn"] = nlohmann_json_t.class_fqn; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, CustomObjArgument& nlohmann_json_t) { + CustomObjArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.class_fqn = nlohmann_json_j.value("class_fqn", nlohmann_json_default_obj.class_fqn); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const Device& nlohmann_json_t) { + nlohmann_json_j["type"] = nlohmann_json_t.type; + nlohmann_json_j["index"] = nlohmann_json_t.index; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Device& nlohmann_json_t) { + Device nlohmann_json_default_obj; + nlohmann_json_t.type = nlohmann_json_j.value("type", nlohmann_json_default_obj.type); + nlohmann_json_t.index = nlohmann_json_j.value("index", nlohmann_json_default_obj.index); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t) { + nlohmann_json_j["graph_module"] = nlohmann_json_t.graph_module; + nlohmann_json_j["opset_version"] = nlohmann_json_t.opset_version; + nlohmann_json_j["range_constraints"] = nlohmann_json_t.range_constraints; + nlohmann_json_j["schema_version"] = nlohmann_json_t.schema_version; + nlohmann_json_j["verifiers"] = nlohmann_json_t.verifiers; + nlohmann_json_j["torch_version"] = nlohmann_json_t.torch_version; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t) { + ExportedProgram nlohmann_json_default_obj; + nlohmann_json_t.graph_module = nlohmann_json_j.value("graph_module", nlohmann_json_default_obj.graph_module); + nlohmann_json_t.opset_version = nlohmann_json_j.value("opset_version", nlohmann_json_default_obj.opset_version); + nlohmann_json_t.range_constraints = nlohmann_json_j.value("range_constraints", nlohmann_json_default_obj.range_constraints); + nlohmann_json_t.schema_version = nlohmann_json_j.value("schema_version", nlohmann_json_default_obj.schema_version); + nlohmann_json_t.verifiers = nlohmann_json_j.value("verifiers", nlohmann_json_default_obj.verifiers); + nlohmann_json_t.torch_version = nlohmann_json_j.value("torch_version", nlohmann_json_default_obj.torch_version); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GradientToParameterSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["parameter_name"] = nlohmann_json_t.parameter_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GradientToParameterSpec& nlohmann_json_t) { + GradientToParameterSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.parameter_name = nlohmann_json_j.value("parameter_name", nlohmann_json_default_obj.parameter_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GradientToUserInputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["user_input_name"] = nlohmann_json_t.user_input_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GradientToUserInputSpec& nlohmann_json_t) { + GradientToUserInputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.user_input_name = nlohmann_json_j.value("user_input_name", nlohmann_json_default_obj.user_input_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const Graph& nlohmann_json_t) { + nlohmann_json_j["inputs"] = nlohmann_json_t.inputs; + nlohmann_json_j["outputs"] = nlohmann_json_t.outputs; + nlohmann_json_j["nodes"] = nlohmann_json_t.nodes; + nlohmann_json_j["tensor_values"] = nlohmann_json_t.tensor_values; + nlohmann_json_j["sym_int_values"] = nlohmann_json_t.sym_int_values; + nlohmann_json_j["sym_bool_values"] = nlohmann_json_t.sym_bool_values; + nlohmann_json_j["is_single_tensor_return"] = nlohmann_json_t.is_single_tensor_return; + nlohmann_json_j["custom_obj_values"] = nlohmann_json_t.custom_obj_values; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_json_t) { + Graph nlohmann_json_default_obj; + nlohmann_json_t.inputs = nlohmann_json_j.value("inputs", nlohmann_json_default_obj.inputs); + nlohmann_json_t.outputs = nlohmann_json_j.value("outputs", nlohmann_json_default_obj.outputs); + nlohmann_json_t.nodes = nlohmann_json_j.value("nodes", nlohmann_json_default_obj.nodes); + nlohmann_json_t.tensor_values = nlohmann_json_j.value("tensor_values", nlohmann_json_default_obj.tensor_values); + nlohmann_json_t.sym_int_values = nlohmann_json_j.value("sym_int_values", nlohmann_json_default_obj.sym_int_values); + nlohmann_json_t.sym_bool_values = nlohmann_json_j.value("sym_bool_values", nlohmann_json_default_obj.sym_bool_values); + nlohmann_json_t.is_single_tensor_return = nlohmann_json_j.value("is_single_tensor_return", nlohmann_json_default_obj.is_single_tensor_return); + nlohmann_json_t.custom_obj_values = nlohmann_json_j.value("custom_obj_values", nlohmann_json_default_obj.custom_obj_values); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GraphArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["graph"] = nlohmann_json_t.graph; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GraphArgument& nlohmann_json_t) { + GraphArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.graph = nlohmann_json_j.value("graph", nlohmann_json_default_obj.graph); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GraphModule& nlohmann_json_t) { + nlohmann_json_j["graph"] = nlohmann_json_t.graph; + nlohmann_json_j["signature"] = nlohmann_json_t.signature; + nlohmann_json_j["module_call_graph"] = nlohmann_json_t.module_call_graph; + nlohmann_json_j["metadata"] = nlohmann_json_t.metadata; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GraphModule& nlohmann_json_t) { + GraphModule nlohmann_json_default_obj; + nlohmann_json_t.graph = nlohmann_json_j.value("graph", nlohmann_json_default_obj.graph); + nlohmann_json_t.signature = nlohmann_json_j.value("signature", nlohmann_json_default_obj.signature); + nlohmann_json_t.module_call_graph = nlohmann_json_j.value("module_call_graph", nlohmann_json_default_obj.module_call_graph); + nlohmann_json_t.metadata = nlohmann_json_j.value("metadata", nlohmann_json_default_obj.metadata); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GraphSignature& nlohmann_json_t) { + nlohmann_json_j["input_specs"] = nlohmann_json_t.input_specs; + nlohmann_json_j["output_specs"] = nlohmann_json_t.output_specs; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GraphSignature& nlohmann_json_t) { + GraphSignature nlohmann_json_default_obj; + nlohmann_json_t.input_specs = nlohmann_json_j.value("input_specs", nlohmann_json_default_obj.input_specs); + nlohmann_json_t.output_specs = nlohmann_json_j.value("output_specs", nlohmann_json_default_obj.output_specs); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToBufferSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["buffer_name"] = nlohmann_json_t.buffer_name; + nlohmann_json_j["persistent"] = nlohmann_json_t.persistent; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToBufferSpec& nlohmann_json_t) { + InputToBufferSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.buffer_name = nlohmann_json_j.value("buffer_name", nlohmann_json_default_obj.buffer_name); + nlohmann_json_t.persistent = nlohmann_json_j.value("persistent", nlohmann_json_default_obj.persistent); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToCustomObjSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["custom_obj_name"] = nlohmann_json_t.custom_obj_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToCustomObjSpec& nlohmann_json_t) { + InputToCustomObjSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.custom_obj_name = nlohmann_json_j.value("custom_obj_name", nlohmann_json_default_obj.custom_obj_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToParameterSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["parameter_name"] = nlohmann_json_t.parameter_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToParameterSpec& nlohmann_json_t) { + InputToParameterSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.parameter_name = nlohmann_json_j.value("parameter_name", nlohmann_json_default_obj.parameter_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToTensorConstantSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["tensor_constant_name"] = nlohmann_json_t.tensor_constant_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToTensorConstantSpec& nlohmann_json_t) { + InputToTensorConstantSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.tensor_constant_name = nlohmann_json_j.value("tensor_constant_name", nlohmann_json_default_obj.tensor_constant_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputTokenSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputTokenSpec& nlohmann_json_t) { + InputTokenSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const LossOutputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, LossOutputSpec& nlohmann_json_t) { + LossOutputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallEntry& nlohmann_json_t) { + nlohmann_json_j["fqn"] = nlohmann_json_t.fqn; + nlohmann_json_j["signature"] = nlohmann_json_t.signature; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallEntry& nlohmann_json_t) { + ModuleCallEntry nlohmann_json_default_obj; + nlohmann_json_t.fqn = nlohmann_json_j.value("fqn", nlohmann_json_default_obj.fqn); + nlohmann_json_t.signature = nlohmann_json_j.value("signature", nlohmann_json_default_obj.signature); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallSignature& nlohmann_json_t) { + nlohmann_json_j["inputs"] = nlohmann_json_t.inputs; + nlohmann_json_j["outputs"] = nlohmann_json_t.outputs; + nlohmann_json_j["in_spec"] = nlohmann_json_t.in_spec; + nlohmann_json_j["out_spec"] = nlohmann_json_t.out_spec; + nlohmann_json_j["forward_arg_names"] = nlohmann_json_t.forward_arg_names; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallSignature& nlohmann_json_t) { + ModuleCallSignature nlohmann_json_default_obj; + nlohmann_json_t.inputs = nlohmann_json_j.value("inputs", nlohmann_json_default_obj.inputs); + nlohmann_json_t.outputs = nlohmann_json_j.value("outputs", nlohmann_json_default_obj.outputs); + nlohmann_json_t.in_spec = nlohmann_json_j.value("in_spec", nlohmann_json_default_obj.in_spec); + nlohmann_json_t.out_spec = nlohmann_json_j.value("out_spec", nlohmann_json_default_obj.out_spec); + nlohmann_json_t.forward_arg_names = nlohmann_json_j.value("forward_arg_names", nlohmann_json_default_obj.forward_arg_names); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const NamedArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, NamedArgument& nlohmann_json_t) { + NamedArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const Node& nlohmann_json_t) { + nlohmann_json_j["target"] = nlohmann_json_t.target; + nlohmann_json_j["inputs"] = nlohmann_json_t.inputs; + nlohmann_json_j["outputs"] = nlohmann_json_t.outputs; + nlohmann_json_j["metadata"] = nlohmann_json_t.metadata; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Node& nlohmann_json_t) { + Node nlohmann_json_default_obj; + nlohmann_json_t.target = nlohmann_json_j.value("target", nlohmann_json_default_obj.target); + nlohmann_json_t.inputs = nlohmann_json_j.value("inputs", nlohmann_json_default_obj.inputs); + nlohmann_json_t.outputs = nlohmann_json_j.value("outputs", nlohmann_json_default_obj.outputs); + nlohmann_json_t.metadata = nlohmann_json_j.value("metadata", nlohmann_json_default_obj.metadata); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const OutputTokenSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, OutputTokenSpec& nlohmann_json_t) { + OutputTokenSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const RangeConstraint& nlohmann_json_t) { + nlohmann_json_j["min_val"] = nlohmann_json_t.min_val; + nlohmann_json_j["max_val"] = nlohmann_json_t.max_val; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, RangeConstraint& nlohmann_json_t) { + RangeConstraint nlohmann_json_default_obj; + nlohmann_json_t.min_val = nlohmann_json_j.value("min_val", nlohmann_json_default_obj.min_val); + nlohmann_json_t.max_val = nlohmann_json_j.value("max_val", nlohmann_json_default_obj.max_val); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const SchemaVersion& nlohmann_json_t) { + nlohmann_json_j["major"] = nlohmann_json_t.major; + nlohmann_json_j["minor"] = nlohmann_json_t.minor; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, SchemaVersion& nlohmann_json_t) { + SchemaVersion nlohmann_json_default_obj; + nlohmann_json_t.major = nlohmann_json_j.value("major", nlohmann_json_default_obj.major); + nlohmann_json_t.minor = nlohmann_json_j.value("minor", nlohmann_json_default_obj.minor); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const SymExpr& nlohmann_json_t) { + nlohmann_json_j["expr_str"] = nlohmann_json_t.expr_str; + nlohmann_json_j["hint"] = nlohmann_json_t.hint; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, SymExpr& nlohmann_json_t) { + SymExpr nlohmann_json_default_obj; + nlohmann_json_t.expr_str = nlohmann_json_j.value("expr_str", nlohmann_json_default_obj.expr_str); + nlohmann_json_t.hint = nlohmann_json_j.value("hint", nlohmann_json_default_obj.hint); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const TensorArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, TensorArgument& nlohmann_json_t) { + TensorArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const TensorMeta& nlohmann_json_t) { + nlohmann_json_j["dtype"] = nlohmann_json_t.dtype; + nlohmann_json_j["sizes"] = nlohmann_json_t.sizes; + nlohmann_json_j["requires_grad"] = nlohmann_json_t.requires_grad; + nlohmann_json_j["device"] = nlohmann_json_t.device; + nlohmann_json_j["strides"] = nlohmann_json_t.strides; + nlohmann_json_j["storage_offset"] = nlohmann_json_t.storage_offset; + nlohmann_json_j["layout"] = nlohmann_json_t.layout; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, TensorMeta& nlohmann_json_t) { + TensorMeta nlohmann_json_default_obj; + nlohmann_json_t.dtype = nlohmann_json_j.value("dtype", nlohmann_json_default_obj.dtype); + nlohmann_json_t.sizes = nlohmann_json_j.value("sizes", nlohmann_json_default_obj.sizes); + nlohmann_json_t.requires_grad = nlohmann_json_j.value("requires_grad", nlohmann_json_default_obj.requires_grad); + nlohmann_json_t.device = nlohmann_json_j.value("device", nlohmann_json_default_obj.device); + nlohmann_json_t.strides = nlohmann_json_j.value("strides", nlohmann_json_default_obj.strides); + nlohmann_json_t.storage_offset = nlohmann_json_j.value("storage_offset", nlohmann_json_default_obj.storage_offset); + nlohmann_json_t.layout = nlohmann_json_j.value("layout", nlohmann_json_default_obj.layout); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const TokenArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, TokenArgument& nlohmann_json_t) { + TokenArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const UserInputMutationSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["user_input_name"] = nlohmann_json_t.user_input_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, UserInputMutationSpec& nlohmann_json_t) { + UserInputMutationSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.user_input_name = nlohmann_json_j.value("user_input_name", nlohmann_json_default_obj.user_input_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const UserInputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, UserInputSpec& nlohmann_json_t) { + UserInputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const UserOutputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlohmann_json_t) { + UserOutputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +} // namespace _export +} // namespace torch + +// clang-format on