diff --git a/.gitattributes b/.gitattributes index e90430175295..f20bd8f1b31b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -5,3 +5,4 @@ .github/scripts/gql_mocks.json linguist-generated=true third_party/LICENSES_BUNDLED.txt linguist-generated=true tools/build/bazel/requirements.txt linguist-generated=true +torch/csrc/utils/generated_serialization_types.h linguist-generated=true diff --git a/build_variables.bzl b/build_variables.bzl index 9cb351a4a090..b84ed46aacb1 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -844,6 +844,7 @@ libtorch_python_core_sources = [ "torch/csrc/fx/node.cpp", "torch/csrc/mps/Module.cpp", "torch/csrc/mtia/Module.cpp", + "torch/csrc/export/pybind.cpp", "torch/csrc/inductor/aoti_package/pybind.cpp", "torch/csrc/inductor/aoti_runner/pybind.cpp", "torch/csrc/inductor/aoti_eager/kernel_holder.cpp", diff --git a/scripts/export/update_schema.py b/scripts/export/update_schema.py index 74d7dcaaec8c..3de818ef2424 100644 --- a/scripts/export/update_schema.py +++ b/scripts/export/update_schema.py @@ -29,7 +29,7 @@ if __name__ == "__main__": commit = schema_check.update_schema() - if os.path.exists(args.prefix + commit.path): + if os.path.exists(args.prefix + commit.yaml_path): if commit.result["SCHEMA_VERSION"] < commit.base["SCHEMA_VERSION"]: raise RuntimeError( f"Schema version downgraded from {commit.base['SCHEMA_VERSION']} to {commit.result['SCHEMA_VERSION']}." @@ -55,17 +55,28 @@ if __name__ == "__main__": + f"Reason: {reason}" ) - header = ( - "# @" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py" + first_line = ( + "@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py" ) - header += f"\n# checksum<<{commit.checksum_result}>>" - payload = dump(commit.result, Dumper=Dumper, sort_keys=False) + checksum = f"checksum<<{commit.checksum_result}>>" + yaml_header = "# " + first_line + yaml_header += "\n# " + checksum + yaml_payload = dump(commit.result, Dumper=Dumper, sort_keys=False) - content = header + "\n" + payload + cpp_header = "// " + first_line + cpp_header += "\n// " + checksum + cpp_header += "\n// clang-format off" + cpp_header += "\n" + commit.cpp_header + cpp_header += "\n// clang-format on" + cpp_header += "\n" + + yaml_content = yaml_header + "\n" + yaml_payload if args.dry_run: - print(content) - print("\nWill write the above schema to" + args.prefix + commit.path) + print(yaml_content) + print("\nWill write the above schema to" + args.prefix + commit.yaml_path) else: - with open(args.prefix + commit.path, "w") as f: - f.write(content) + with open(args.prefix + commit.yaml_path, "w") as f: + f.write(yaml_content) + with open(args.prefix + commit.cpp_header_path, "w") as f: + f.write(cpp_header) diff --git a/test/export/test_cpp_serdes.py b/test/export/test_cpp_serdes.py new file mode 100644 index 000000000000..7897e123404d --- /dev/null +++ b/test/export/test_cpp_serdes.py @@ -0,0 +1,60 @@ +# Owner(s): ["oncall: export"] + + +import torch +from torch._export.serde.serialize import deserialize, serialize + + +try: + from . import test_export, testing +except ImportError: + import test_export # @manual=fbcode//caffe2/test:test_export-library + import testing # @manual=fbcode//caffe2/test:test_export-library + +from torch.export import export + + +test_classes = {} + + +def mocked_cpp_serdes_export(*args, **kwargs): + ep = export(*args, **kwargs) + try: + payload = serialize(ep) + except Exception: + return ep + cpp_ep = torch._C._export.deserialize_exported_program(payload.exported_program) + loaded_json = torch._C._export.serialize_exported_program(cpp_ep) + payload.exported_program = loaded_json.encode() + loaded_ep = deserialize(payload) + return loaded_ep + + +def make_dynamic_cls(cls): + cls_prefix = "CppSerdes" + + test_class = testing.make_test_cls_with_mocked_export( + cls, + cls_prefix, + "_cpp_serdes", + mocked_cpp_serdes_export, + xfail_prop="_expected_failure_cpp_serdes", + ) + + test_classes[test_class.__name__] = test_class + # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING + globals()[test_class.__name__] = test_class + test_class.__module__ = __name__ + + +tests = [ + test_export.TestExport, +] +for test in tests: + make_dynamic_cls(test) +del test + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/export/test_export.py b/test/export/test_export.py index 880f8ad256b1..c933c748fd48 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2951,6 +2951,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): export(N(), inputs, dynamic_shapes=dynamic_shapes) @testing.expectedFailureSerDer # no unbacked bindings after deserialization? + @testing.expectedFailureCppSerDes # no unbacked bindings after deserialization? @testing.expectedFailureSerDerNonStrict def test_unbacked_bindings_for_divisible_u_symint(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: @@ -3673,6 +3674,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): self._test_export_same_as_eager(kw_func, args, kwargs) @testing.expectedFailureSerDer # we don't save placeholder metadata + @testing.expectedFailureCppSerDes # we don't save placeholder metadata @testing.expectedFailureSerDerNonStrict @testing.expectedFailureNonStrict @testing.expectedFailureTrainingIRToRunDecompNonStrict # source_fn_stack failure @@ -8078,6 +8080,7 @@ def forward(self, x, y): export(f, (inputs,), dynamic_shapes=dynamic_shapes) @testing.expectedFailureRetraceabilityNonStrict + @testing.expectedFailureCppSerDes # dynamic shape serialization def test_disable_forced_specializations_ok(self): # check that we don't force specialization, and defer to runtime asserts # with allow_complex_guards_as_runtime_asserts=True to successfully export @@ -8198,6 +8201,7 @@ def forward(self, x, y): # TODO requires_grad doesn't seem to work with serialization. @testing.expectedFailureSerDer + @testing.expectedFailureCppSerDes @testing.expectedFailureSerDerNonStrict def test_preserve_requires_grad_placeholders(self): class Module(torch.nn.Module): @@ -8536,6 +8540,7 @@ def forward(self, x, y): ep.graph_module.code ) + @testing.expectedFailureCppSerDes def test_slice_with_floordiv(self): # slice operation emits runtime assert s0//2 <= s1 class M1(torch.nn.Module): @@ -9105,6 +9110,7 @@ def forward(self, x): _load_dynamic_shapes(spec, from_dict=True) @testing.expectedFailureSerDer # TODO(pianpwk): PowByNatural valuerange deserialization + @testing.expectedFailureCppSerDes # TODO(pianpwk): PowByNatural valuerange deserialization @testing.expectedFailureSerDerNonStrict @testing.expectedFailureRetraceabilityNonStrict def test_dim_dynamic(self): diff --git a/test/export/test_schema.py b/test/export/test_schema.py index 7950f40086ac..e98c289520d0 100644 --- a/test/export/test_schema.py +++ b/test/export/test_schema.py @@ -106,11 +106,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [4, 1]) @@ -138,11 +140,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [4, 1]) @@ -173,11 +177,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) @@ -231,11 +237,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) @@ -259,11 +267,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) @@ -294,11 +304,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) @@ -326,11 +338,13 @@ Example(s): commit = _Commit( result=src, checksum_result="", - path="", + yaml_path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", + cpp_header="", + cpp_header_path="", ) next_version, _ = check(commit) self.assertEqual(next_version, [4, 1]) diff --git a/test/export/testing.py b/test/export/testing.py index ed72f219eb63..3e98ac3024a2 100644 --- a/test/export/testing.py +++ b/test/export/testing.py @@ -284,3 +284,8 @@ def expectedFailureSerDerPreDispatch(fn): def expectedFailurePreDispatchRunDecomp(fn): fn._expected_failure_pre_dispatch = True return fn + + +def expectedFailureCppSerDes(fn): + fn._expected_failure_cpp_serdes = True + return fn diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 6dd6dd3ccb4b..ce851a826063 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -64,6 +64,7 @@ from torch.utils._python_dispatch import TorchDispatchMode from . import ( _aoti, + _export, _cpu, _dynamo, _functorch, diff --git a/torch/_C/_export.pyi b/torch/_C/_export.pyi new file mode 100644 index 000000000000..5351945b9d51 --- /dev/null +++ b/torch/_C/_export.pyi @@ -0,0 +1,10 @@ +# Defined in torch/csrc/export/pybind.cpp + +class CppExportedProgram: ... + +def deserialize_exported_program( + serialized_program: str, +) -> CppExportedProgram: ... +def serialize_exported_program( + cpp_exported_program: CppExportedProgram, +) -> str: ... diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index 9116e136cf78..b71c02829ceb 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -1,10 +1,11 @@ # mypy: allow-untyped-defs import dataclasses import hashlib +import inspect import re import typing from enum import IntEnum -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from torch._export.serde import schema from torch._export.serde.union import _Union @@ -20,43 +21,84 @@ def _check(x, msg): def _staged_schema(): - ret: Dict[str, Any] = {} + yaml_ret: Dict[str, Any] = {} defs = {} + cpp_enum_defs: Dict[str, str] = {} + cpp_class_defs: Dict[str, str] = {} + cpp_type_decls: List[str] = [] + cpp_json_defs: List[str] = [] - def _handle_aggregate(ty): - def dump_type(t): + def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def dump_type(t) -> Tuple[str, str]: + TYPE_MAP = { + str: "std::string", + int: "int64_t", + float: "double", + bool: "bool", + } if isinstance(t, type): - return t.__name__ + if t.__name__ in cpp_enum_defs: + return t.__name__, "int64_t" + else: + return t.__name__, TYPE_MAP.get(t, t.__name__) elif isinstance(t, str): assert t in defs - return t + assert t not in cpp_enum_defs + assert "[" not in t + return t, f"ForwardRef<{t}>" elif o := typing.get_origin(t): # Lemme know if there's a better way to do this. if o == list: - head = "List" + yaml_head, cpp_head = "List", "std::vector" elif o == dict: - head = "Dict" + yaml_head, cpp_head = "Dict", "std::unordered_map" elif o == tuple: if typing.get_args(t) == (): - return "Tuple[()]" - head = "Tuple" + return "Tuple[()]", "std::tuple<>" + yaml_head, cpp_head = "Tuple", "std::tuple" elif o == Union: args = typing.get_args(t) assert len(args) == 2 and args[1] == type(None) - return f"Optional[{dump_type(args[0])}]" + yaml_type, cpp_type = dump_type(args[0]) + return f"Optional[{yaml_type}]", f"std::optional<{cpp_type}>" else: raise AssertionError(f"Type {t} is not supported in export schema.") - return ( - f"{head}[{', '.join([dump_type(x) for x in typing.get_args(t)])}]" + yaml_arg_types, cpp_arg_types = zip( + *[dump_type(x) for x in typing.get_args(t)] + ) + return (f"{yaml_head}[{', '.join(yaml_arg_types)}]"), ( + f"{cpp_head}<{', '.join(cpp_arg_types)}>" ) elif t == (): - return "()" + return "()", "" else: raise AssertionError(f"Type {t} is not supported in export schema.") - def dump_field(f): - t = dump_type(f.type) + def dump_cpp_value(v) -> str: + if v is None: + return "std::nullopt" + elif v is True: + return "true" + elif v is False: + return "false" + elif v == {}: + return "{}" + elif v == []: + return "{}" + elif v == (): + return "{}" + elif isinstance(v, str): + return f'"{v}"' + else: + raise AssertionError( + f"Default value {v} is not supported yet in export schema." + ) + + def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str]]: + t, cpp = dump_type(f.type) ret = {"type": t} + cpp_type = cpp + cpp_default: Optional[str] = None value = dataclasses.MISSING if f.default is not dataclasses.MISSING: @@ -67,24 +109,149 @@ def _staged_schema(): if value is not dataclasses.MISSING: default = str(value) ret["default"] = default + cpp_default = dump_cpp_value(value) if t.startswith("Optional[") and value is not None: raise AssertionError( f"Optional field {ty.__name__}.{f.name} must have default value to be None." ) - return ret + return ret, cpp_type, cpp_default - return {f.name: dump_field(f) for f in dataclasses.fields(ty)} + yaml_ret = {} + cpp_ret = {} + for f in dataclasses.fields(ty): + yaml_res, cpp_type, cpp_default = dump_field(f) + yaml_ret[f.name] = yaml_res + cpp_ret[f.name] = {"cpp_type": cpp_type, "cpp_default": cpp_default} + return yaml_ret, cpp_ret def _handle_int_enum(name, ty): - ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} + yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} + cpp_enum_defs[ + name + ] = f""" +enum class {name} {{ +{chr(10).join([f" {x.name} = {x.value}," for x in ty])} +}}; +""" def _handle_struct(name, ty): - ret[name] = {"kind": "struct", "fields": _handle_aggregate(ty)} + fields, cpp_fields = _handle_aggregate(ty) + yaml_ret[name] = {"kind": "struct", "fields": fields} + field_decls = "\n".join( + f" {f['cpp_type']} {name}{' = ' + f['cpp_default'] if f['cpp_default'] is not None else ''};" + for name, f in cpp_fields.items() + ) + + def accessor(name, ty): + type_name = fields[name]["type"] + if type_name in cpp_enum_defs: + return f""" + {type_name} get_{name}() const {{ + return static_cast<{type_name}>({name}); + }} +""" + return f""" + const {ty}& get_{name}() const {{ + return {name}; + }} +""" + + to_json_decl = f"void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t)" + to_json_def = f"""{{ +{chr(10).join([f' nlohmann_json_j["{name}"] = nlohmann_json_t.{name};' for name, f in cpp_fields.items()])} +}} +""" + from_json_decl = f"void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t)" + + from_json_def = f"""{{ + {name} nlohmann_json_default_obj; +{chr(10).join( + [f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});' + for name, f in cpp_fields.items()])} +}} +""" + cpp_class_defs[ + name + ] = f""" +class {name} {{ + private: +{field_decls} + + public: +{"".join([accessor(name, f["cpp_type"]) for name, f in cpp_fields.items()])} + friend {to_json_decl}; + friend {from_json_decl}; +}}; +""" + cpp_json_defs.append(f"inline {to_json_decl} {to_json_def}") + cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}") + cpp_type_decls.append(f"class {name};") def _handle_union(name, ty): - ret[name] = {"kind": "union", "fields": _handle_aggregate(ty)} + fields, cpp_fields = _handle_aggregate(ty) + yaml_ret[name] = {"kind": "union", "fields": fields} + + def accessor(name, ty, idx): + return f""" + const {ty}& get_{name}() const {{ + return std::get<{idx + 1}>(variant_); + }} +""" + + to_json_branches = "".join( + [ + f""" + if (nlohmann_json_t.tag_ == Tag::{name.upper()}) {{ + nlohmann_json_j["{name}"] = nlohmann_json_t.get_{name}(); + return; + }}""" + for idx, (name, f) in enumerate(cpp_fields.items()) + ] + ) + from_json_branches = "".join( + [ + f""" + if (nlohmann_json_j.contains("{name}")) {{ + nlohmann_json_t.variant_.emplace<{idx + 1}>(nlohmann_json_j.at("{name}").template get<{f["cpp_type"]}>()); + nlohmann_json_t.tag_ = Tag::{name.upper()}; + return; + }}""" + for idx, (name, f) in enumerate(cpp_fields.items()) + ] + ) + + cpp_class_defs[ + name + ] = f""" +class {name} {{ + struct Void {{}}; + + public: + enum class Tag {{ + {", ".join([name.upper() for name in cpp_fields])} + }}; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const {{ + return tag_; + }} +{"".join([accessor(name, f["cpp_type"], idx) for idx, (name, f) in enumerate(cpp_fields.items())])} + friend void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t) {{ +{to_json_branches} + }} + + friend void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t) {{ +{from_json_branches} + }} +}}; +""" + cpp_type_decls.append(f"class {name};") for name in dir(schema): if name.startswith("_"): @@ -97,11 +264,13 @@ def _staged_schema(): defs[name] = value + class_ordering = {} for name, value in defs.items(): if isinstance(value, type): if issubclass(value, IntEnum): _handle_int_enum(name, value) elif dataclasses.is_dataclass(value): + class_ordering[name] = inspect.findsource(value)[1] if issubclass(value, _Union): _handle_union(name, value) else: @@ -113,11 +282,103 @@ def _staged_schema(): else: raise AssertionError(f"Unknown variable {name}: {value}") - ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"]) - assert all(x > 0 for x in ret["SCHEMA_VERSION"]) - ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"] - assert ret["TREESPEC_VERSION"] > 0 - return ret + yaml_ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"]) + assert all(x > 0 for x in yaml_ret["SCHEMA_VERSION"]) + yaml_ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"] + assert yaml_ret["TREESPEC_VERSION"] > 0 + + cpp_header = f""" +#pragma once + +#include +#include +#include +#include +#include + +#include + +#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN +#define NLOHMANN_JSON_NAMESPACE_BEGIN namespace nlohmann {{ +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_END +#define NLOHMANN_JSON_NAMESPACE_END }} +#endif + +// https://github.com/nlohmann/json/pull/2117 +NLOHMANN_JSON_NAMESPACE_BEGIN +template +struct adl_serializer> {{ + static void to_json(json& j, const std::optional& opt) {{ + if (opt == std::nullopt) {{ + j = nullptr; + }} else {{ + j = *opt; // this will call adl_serializer::to_json which will + // find the free function to_json in T's namespace! + }} + }} + + static void from_json(const json& j, std::optional& opt) {{ + if (j.is_null()) {{ + opt = std::nullopt; + }} else {{ + opt = j.template get(); // same as above, but with + // adl_serializer::from_json + }} + }} +}}; +NLOHMANN_JSON_NAMESPACE_END + +namespace torch {{ +namespace _export {{ + +template +class ForwardRef {{ + static_assert(!std::is_reference_v, "ForwardRef cannot be a reference type"); + + public: + ForwardRef(): ptr_(std::make_unique()) {{}} + ForwardRef(ForwardRef&&) = default; + ForwardRef(const ForwardRef& other): ptr_(std::make_unique(*other.ptr_)) {{}} + ForwardRef& operator=(ForwardRef&&) = default; + ForwardRef& operator=(const ForwardRef& other) {{ + ptr_ = std::make_unique(*other.ptr_); + }} + const T& operator*() const {{ + return *ptr_; + }} + + const T* operator->() const {{ + return ptr_.get(); + }} + + void emplace(T&& t) {{ + ptr_ = std::make_unique(std::move(t)); + }} + + private: + std::unique_ptr ptr_; +}}; + +template +void to_json(nlohmann::json& j, const ForwardRef& p) {{ + j = *p; +}} + +template +void from_json(const nlohmann::json& j, ForwardRef& p) {{ + p.emplace(j.template get()); +}} + +{chr(10).join(cpp_type_decls)} +{"".join(cpp_enum_defs.values())} +{"".join(dict(sorted(cpp_class_defs.items(), key=lambda x: class_ordering[x[0]])).values())} +{chr(10).join(cpp_json_defs)} +}} // namespace _export +}} // namespace torch +""" + return yaml_ret, cpp_header def _diff_schema(dst, src): @@ -193,11 +454,13 @@ def _hash_schema(s): class _Commit: result: Dict[str, Any] checksum_result: str - path: str + yaml_path: str additions: Dict[str, Any] subtractions: Dict[str, Any] base: Dict[str, Any] checksum_base: Optional[str] + cpp_header: str + cpp_header_path: str def update_schema(): @@ -217,16 +480,22 @@ def update_schema(): checksum_base = None dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None} - src = _staged_schema() + src, cpp_header = _staged_schema() additions, subtractions = _diff_schema(dst, src) + yaml_path = __package__.replace(".", "/") + "/schema.yaml" + torch_prefix = "torch/" + assert yaml_path.startswith(torch_prefix) # sanity check + return _Commit( result=src, checksum_result=_hash_schema(src), - path=__package__.replace(".", "/") + "/schema.yaml", + yaml_path=yaml_path, additions=additions, subtractions=subtractions, base=dst, checksum_base=checksum_base, + cpp_header=cpp_header, + cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h", ) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 33641db5dd6b..f3c1ec6cd364 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -697,7 +697,7 @@ class GraphModuleSerializer(metaclass=Final): return inputs def is_sym_int_arg(self, arg) -> bool: - return isinstance(arg, int) or ( + return type(arg) is int or ( isinstance(arg, torch.fx.Node) and arg.name in self.graph_state.sym_int_values ) @@ -770,13 +770,13 @@ class GraphModuleSerializer(metaclass=Final): # For regular FX graph, SymInt arg should be a fx.Node with # self.is_sym_int_arg(arg) being true return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg))) - elif isinstance(arg, bool): + elif type(arg) is bool: return Argument.create(as_bool=arg) - elif isinstance(arg, str): + elif type(arg) is str: return Argument.create(as_string=arg) - elif isinstance(arg, int): + elif type(arg) is int: return Argument.create(as_int=arg) - elif isinstance(arg, float): + elif type(arg) is float: return Argument.create(as_float=arg) elif arg is None: return Argument.create(as_none=()) @@ -814,14 +814,13 @@ class GraphModuleSerializer(metaclass=Final): ) return Argument.create(as_tensors=[]) - # Must check bool first, as bool is also treated as int - if all(isinstance(a, bool) for a in arg): + if all(type(a) is bool for a in arg): return Argument.create(as_bools=list(arg)) - elif all(isinstance(a, int) for a in arg): + elif all(type(a) is int for a in arg): return Argument.create(as_ints=list(arg)) - elif all(isinstance(a, float) for a in arg): + elif all(type(a) is float for a in arg): return Argument.create(as_floats=list(arg)) - elif all(isinstance(a, str) for a in arg): + elif all(type(a) is str for a in arg): return Argument.create(as_strings=list(arg)) elif all(isinstance(a, torch.SymInt) for a in arg): # This is a special branch for handling SymInt args in inductor's @@ -837,7 +836,7 @@ class GraphModuleSerializer(metaclass=Final): for a in arg: if isinstance(a, torch.fx.Node): values.append(SymIntArgument.create(as_name=a.name)) - elif isinstance(a, int): + elif type(a) is int: values.append(SymIntArgument.create(as_int=a)) return Argument.create(as_sym_ints=values) elif all(self.is_sym_bool_arg(a) for a in arg): @@ -952,13 +951,13 @@ class GraphModuleSerializer(metaclass=Final): def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: if spec.kind == ep.InputKind.USER_INPUT: if isinstance(spec.arg, ep.ConstantArgument): - if isinstance(spec.arg.value, int): + if type(spec.arg.value) is int: constant_spec = ConstantValue.create(as_int=spec.arg.value) - elif isinstance(spec.arg.value, bool): + elif type(spec.arg.value) is bool: constant_spec = ConstantValue.create(as_bool=spec.arg.value) - elif isinstance(spec.arg.value, str): + elif type(spec.arg.value) is str: constant_spec = ConstantValue.create(as_string=spec.arg.value) - elif isinstance(spec.arg.value, float): + elif type(spec.arg.value) is float: constant_spec = ConstantValue.create(as_float=spec.arg.value) elif spec.arg.value is None: constant_spec = ConstantValue.create(as_none=()) @@ -1548,7 +1547,7 @@ class GraphModuleDeserializer(metaclass=Final): return self.shape_env.create_symintnode(sym, hint=hint) elif s.type == "as_int": - assert isinstance(val, int) + assert type(val) is int return val else: raise SerializeError( diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index a9d68d7b4bf2..fbd5c3ea64df 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -68,6 +68,7 @@ #include #include #include +#include #include #include #include @@ -1773,6 +1774,7 @@ PyObject* initModule() { torch::profiler::initPythonBindings(module); torch::python::init_bindings(module); torch::lazy::initLazyBindings(module); + torch::_export::initExportBindings(module); torch::inductor::initAOTIRunnerBindings(module); torch::inductor::initAOTIPackageBindings(module); #ifdef USE_ITT diff --git a/torch/csrc/export/pybind.cpp b/torch/csrc/export/pybind.cpp new file mode 100644 index 000000000000..458b08c3f361 --- /dev/null +++ b/torch/csrc/export/pybind.cpp @@ -0,0 +1,20 @@ +#include +#include + +namespace torch::_export { + +void initExportBindings(PyObject* module) { + auto rootModule = py::handle(module).cast(); + auto m = rootModule.def_submodule("_export"); + + py::class_(m, "CppExportedProgram"); + + m.def("deserialize_exported_program", [](const std::string& serialized) { + return nlohmann::json::parse(serialized).get(); + }); + + m.def("serialize_exported_program", [](const ExportedProgram& ep) { + return nlohmann::json(ep).dump(); + }); +} +} // namespace torch::_export diff --git a/torch/csrc/export/pybind.h b/torch/csrc/export/pybind.h new file mode 100644 index 000000000000..75b954cdccee --- /dev/null +++ b/torch/csrc/export/pybind.h @@ -0,0 +1,7 @@ +#include + +namespace torch::_export { + +void initExportBindings(PyObject* module); + +} // namespace torch::_export diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h new file mode 100644 index 000000000000..934a73022ba1 --- /dev/null +++ b/torch/csrc/utils/generated_serialization_types.h @@ -0,0 +1,2188 @@ +// @generated by update_schema.py +// checksum<<8e27d48014d4ec1c773aef056c0c20b61bead54be8338e95d3347d3422472b9a>> +// clang-format off + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN +#define NLOHMANN_JSON_NAMESPACE_BEGIN namespace nlohmann { +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_END +#define NLOHMANN_JSON_NAMESPACE_END } +#endif + +// https://github.com/nlohmann/json/pull/2117 +NLOHMANN_JSON_NAMESPACE_BEGIN +template +struct adl_serializer> { + static void to_json(json& j, const std::optional& opt) { + if (opt == std::nullopt) { + j = nullptr; + } else { + j = *opt; // this will call adl_serializer::to_json which will + // find the free function to_json in T's namespace! + } + } + + static void from_json(const json& j, std::optional& opt) { + if (j.is_null()) { + opt = std::nullopt; + } else { + opt = j.template get(); // same as above, but with + // adl_serializer::from_json + } + } +}; +NLOHMANN_JSON_NAMESPACE_END + +namespace torch { +namespace _export { + +template +class ForwardRef { + static_assert(!std::is_reference_v, "ForwardRef cannot be a reference type"); + + public: + ForwardRef(): ptr_(std::make_unique()) {} + ForwardRef(ForwardRef&&) = default; + ForwardRef(const ForwardRef& other): ptr_(std::make_unique(*other.ptr_)) {} + ForwardRef& operator=(ForwardRef&&) = default; + ForwardRef& operator=(const ForwardRef& other) { + ptr_ = std::make_unique(*other.ptr_); + } + const T& operator*() const { + return *ptr_; + } + + const T* operator->() const { + return ptr_.get(); + } + + void emplace(T&& t) { + ptr_ = std::make_unique(std::move(t)); + } + + private: + std::unique_ptr ptr_; +}; + +template +void to_json(nlohmann::json& j, const ForwardRef& p) { + j = *p; +} + +template +void from_json(const nlohmann::json& j, ForwardRef& p) { + p.emplace(j.template get()); +} + +class Argument; +class BufferMutationSpec; +class ConstantInputSpec; +class ConstantValue; +class CustomObjArgument; +class Device; +class ExportedProgram; +class GradientToParameterSpec; +class GradientToUserInputSpec; +class Graph; +class GraphArgument; +class GraphModule; +class GraphSignature; +class InputSpec; +class InputToBufferSpec; +class InputToCustomObjSpec; +class InputToParameterSpec; +class InputToTensorConstantSpec; +class InputTokenSpec; +class LossOutputSpec; +class ModuleCallEntry; +class ModuleCallSignature; +class NamedArgument; +class Node; +class OptionalTensorArgument; +class OutputSpec; +class OutputTokenSpec; +class RangeConstraint; +class SchemaVersion; +class SymBool; +class SymBoolArgument; +class SymExpr; +class SymExprHint; +class SymInt; +class SymIntArgument; +class TensorArgument; +class TensorMeta; +class TokenArgument; +class UserInputMutationSpec; +class UserInputSpec; +class UserOutputSpec; + +enum class Layout { + Unknown = 0, + SparseCoo = 1, + SparseCsr = 2, + SparseCsc = 3, + SparseBsr = 4, + SparseBsc = 5, + _mkldnn = 6, + Strided = 7, +}; + +enum class MemoryFormat { + Unknown = 0, + ContiguousFormat = 1, + ChannelsLast = 2, + ChannelsLast3d = 3, + PreserveFormat = 4, +}; + +enum class ScalarType { + UNKNOWN = 0, + BYTE = 1, + CHAR = 2, + SHORT = 3, + INT = 4, + LONG = 5, + HALF = 6, + FLOAT = 7, + DOUBLE = 8, + COMPLEXHALF = 9, + COMPLEXFLOAT = 10, + COMPLEXDOUBLE = 11, + BOOL = 12, + BFLOAT16 = 13, + UINT16 = 28, +}; + + +class Device { + private: + std::string type; + std::optional index = std::nullopt; + + public: + + const std::string& get_type() const { + return type; + } + + const std::optional& get_index() const { + return index; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Device& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Device& nlohmann_json_t); +}; + +class SymExprHint { + struct Void {}; + + public: + enum class Tag { + AS_INT, AS_FLOAT, AS_BOOL + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const int64_t& get_as_int() const { + return std::get<1>(variant_); + } + + const double& get_as_float() const { + return std::get<2>(variant_); + } + + const bool& get_as_bool() const { + return std::get<3>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymExprHint& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymExprHint& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + } +}; + +class SymExpr { + private: + std::string expr_str; + std::optional hint = std::nullopt; + + public: + + const std::string& get_expr_str() const { + return expr_str; + } + + const std::optional& get_hint() const { + return hint; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymExpr& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, SymExpr& nlohmann_json_t); +}; + +class SymInt { + struct Void {}; + + public: + enum class Tag { + AS_EXPR, AS_INT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const SymExpr& get_as_expr() const { + return std::get<1>(variant_); + } + + const int64_t& get_as_int() const { + return std::get<2>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymInt& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_EXPR) { + nlohmann_json_j["as_expr"] = nlohmann_json_t.get_as_expr(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymInt& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_expr")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_expr").template get()); + nlohmann_json_t.tag_ = Tag::AS_EXPR; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + } +}; + +class SymBool { + struct Void {}; + + public: + enum class Tag { + AS_EXPR, AS_BOOL + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const SymExpr& get_as_expr() const { + return std::get<1>(variant_); + } + + const bool& get_as_bool() const { + return std::get<2>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymBool& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_EXPR) { + nlohmann_json_j["as_expr"] = nlohmann_json_t.get_as_expr(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymBool& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_expr")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_expr").template get()); + nlohmann_json_t.tag_ = Tag::AS_EXPR; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + } +}; + +class TensorMeta { + private: + int64_t dtype; + std::vector sizes; + bool requires_grad; + Device device; + std::vector strides; + SymInt storage_offset; + int64_t layout; + + public: + + ScalarType get_dtype() const { + return static_cast(dtype); + } + + const std::vector& get_sizes() const { + return sizes; + } + + const bool& get_requires_grad() const { + return requires_grad; + } + + const Device& get_device() const { + return device; + } + + const std::vector& get_strides() const { + return strides; + } + + const SymInt& get_storage_offset() const { + return storage_offset; + } + + Layout get_layout() const { + return static_cast(layout); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const TensorMeta& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, TensorMeta& nlohmann_json_t); +}; + +class SymIntArgument { + struct Void {}; + + public: + enum class Tag { + AS_NAME, AS_INT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const std::string& get_as_name() const { + return std::get<1>(variant_); + } + + const int64_t& get_as_int() const { + return std::get<2>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymIntArgument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NAME) { + nlohmann_json_j["as_name"] = nlohmann_json_t.get_as_name(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymIntArgument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_name")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_name").template get()); + nlohmann_json_t.tag_ = Tag::AS_NAME; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + } +}; + +class SymBoolArgument { + struct Void {}; + + public: + enum class Tag { + AS_NAME, AS_BOOL + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const std::string& get_as_name() const { + return std::get<1>(variant_); + } + + const bool& get_as_bool() const { + return std::get<2>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SymBoolArgument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NAME) { + nlohmann_json_j["as_name"] = nlohmann_json_t.get_as_name(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, SymBoolArgument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_name")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_name").template get()); + nlohmann_json_t.tag_ = Tag::AS_NAME; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + } +}; + +class TensorArgument { + private: + std::string name; + + public: + + const std::string& get_name() const { + return name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const TensorArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, TensorArgument& nlohmann_json_t); +}; + +class TokenArgument { + private: + std::string name; + + public: + + const std::string& get_name() const { + return name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const TokenArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, TokenArgument& nlohmann_json_t); +}; + +class OptionalTensorArgument { + struct Void {}; + + public: + enum class Tag { + AS_TENSOR, AS_NONE + }; + + private: + std::variant> variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const TensorArgument& get_as_tensor() const { + return std::get<1>(variant_); + } + + const std::tuple<>& get_as_none() const { + return std::get<2>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const OptionalTensorArgument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_TENSOR) { + nlohmann_json_j["as_tensor"] = nlohmann_json_t.get_as_tensor(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_NONE) { + nlohmann_json_j["as_none"] = nlohmann_json_t.get_as_none(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, OptionalTensorArgument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_tensor")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_tensor").template get()); + nlohmann_json_t.tag_ = Tag::AS_TENSOR; + return; + } + if (nlohmann_json_j.contains("as_none")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_none").template get>()); + nlohmann_json_t.tag_ = Tag::AS_NONE; + return; + } + } +}; + +class GraphArgument { + private: + std::string name; + ForwardRef graph; + + public: + + const std::string& get_name() const { + return name; + } + + const ForwardRef& get_graph() const { + return graph; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GraphArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GraphArgument& nlohmann_json_t); +}; + +class CustomObjArgument { + private: + std::string name; + std::string class_fqn; + + public: + + const std::string& get_name() const { + return name; + } + + const std::string& get_class_fqn() const { + return class_fqn; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const CustomObjArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, CustomObjArgument& nlohmann_json_t); +}; + +class Argument { + struct Void {}; + + public: + enum class Tag { + AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR + }; + + private: + std::variant, TensorArgument, std::vector, int64_t, std::vector, double, std::vector, std::string, std::vector, SymIntArgument, std::vector, ScalarType, MemoryFormat, Layout, Device, bool, std::vector, SymBoolArgument, std::vector, GraphArgument, std::vector, CustomObjArgument, std::string> variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const std::tuple<>& get_as_none() const { + return std::get<1>(variant_); + } + + const TensorArgument& get_as_tensor() const { + return std::get<2>(variant_); + } + + const std::vector& get_as_tensors() const { + return std::get<3>(variant_); + } + + const int64_t& get_as_int() const { + return std::get<4>(variant_); + } + + const std::vector& get_as_ints() const { + return std::get<5>(variant_); + } + + const double& get_as_float() const { + return std::get<6>(variant_); + } + + const std::vector& get_as_floats() const { + return std::get<7>(variant_); + } + + const std::string& get_as_string() const { + return std::get<8>(variant_); + } + + const std::vector& get_as_strings() const { + return std::get<9>(variant_); + } + + const SymIntArgument& get_as_sym_int() const { + return std::get<10>(variant_); + } + + const std::vector& get_as_sym_ints() const { + return std::get<11>(variant_); + } + + const ScalarType& get_as_scalar_type() const { + return std::get<12>(variant_); + } + + const MemoryFormat& get_as_memory_format() const { + return std::get<13>(variant_); + } + + const Layout& get_as_layout() const { + return std::get<14>(variant_); + } + + const Device& get_as_device() const { + return std::get<15>(variant_); + } + + const bool& get_as_bool() const { + return std::get<16>(variant_); + } + + const std::vector& get_as_bools() const { + return std::get<17>(variant_); + } + + const SymBoolArgument& get_as_sym_bool() const { + return std::get<18>(variant_); + } + + const std::vector& get_as_sym_bools() const { + return std::get<19>(variant_); + } + + const GraphArgument& get_as_graph() const { + return std::get<20>(variant_); + } + + const std::vector& get_as_optional_tensors() const { + return std::get<21>(variant_); + } + + const CustomObjArgument& get_as_custom_obj() const { + return std::get<22>(variant_); + } + + const std::string& get_as_operator() const { + return std::get<23>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Argument& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NONE) { + nlohmann_json_j["as_none"] = nlohmann_json_t.get_as_none(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_TENSOR) { + nlohmann_json_j["as_tensor"] = nlohmann_json_t.get_as_tensor(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_TENSORS) { + nlohmann_json_j["as_tensors"] = nlohmann_json_t.get_as_tensors(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INTS) { + nlohmann_json_j["as_ints"] = nlohmann_json_t.get_as_ints(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOATS) { + nlohmann_json_j["as_floats"] = nlohmann_json_t.get_as_floats(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_STRING) { + nlohmann_json_j["as_string"] = nlohmann_json_t.get_as_string(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_STRINGS) { + nlohmann_json_j["as_strings"] = nlohmann_json_t.get_as_strings(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_INT) { + nlohmann_json_j["as_sym_int"] = nlohmann_json_t.get_as_sym_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_INTS) { + nlohmann_json_j["as_sym_ints"] = nlohmann_json_t.get_as_sym_ints(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SCALAR_TYPE) { + nlohmann_json_j["as_scalar_type"] = nlohmann_json_t.get_as_scalar_type(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_MEMORY_FORMAT) { + nlohmann_json_j["as_memory_format"] = nlohmann_json_t.get_as_memory_format(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_LAYOUT) { + nlohmann_json_j["as_layout"] = nlohmann_json_t.get_as_layout(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_DEVICE) { + nlohmann_json_j["as_device"] = nlohmann_json_t.get_as_device(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOLS) { + nlohmann_json_j["as_bools"] = nlohmann_json_t.get_as_bools(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_BOOL) { + nlohmann_json_j["as_sym_bool"] = nlohmann_json_t.get_as_sym_bool(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_SYM_BOOLS) { + nlohmann_json_j["as_sym_bools"] = nlohmann_json_t.get_as_sym_bools(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_GRAPH) { + nlohmann_json_j["as_graph"] = nlohmann_json_t.get_as_graph(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_OPTIONAL_TENSORS) { + nlohmann_json_j["as_optional_tensors"] = nlohmann_json_t.get_as_optional_tensors(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_CUSTOM_OBJ) { + nlohmann_json_j["as_custom_obj"] = nlohmann_json_t.get_as_custom_obj(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_OPERATOR) { + nlohmann_json_j["as_operator"] = nlohmann_json_t.get_as_operator(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, Argument& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_none")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_none").template get>()); + nlohmann_json_t.tag_ = Tag::AS_NONE; + return; + } + if (nlohmann_json_j.contains("as_tensor")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_tensor").template get()); + nlohmann_json_t.tag_ = Tag::AS_TENSOR; + return; + } + if (nlohmann_json_j.contains("as_tensors")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_tensors").template get>()); + nlohmann_json_t.tag_ = Tag::AS_TENSORS; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + if (nlohmann_json_j.contains("as_ints")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("as_ints").template get>()); + nlohmann_json_t.tag_ = Tag::AS_INTS; + return; + } + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<6>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; + return; + } + if (nlohmann_json_j.contains("as_floats")) { + nlohmann_json_t.variant_.emplace<7>(nlohmann_json_j.at("as_floats").template get>()); + nlohmann_json_t.tag_ = Tag::AS_FLOATS; + return; + } + if (nlohmann_json_j.contains("as_string")) { + nlohmann_json_t.variant_.emplace<8>(nlohmann_json_j.at("as_string").template get()); + nlohmann_json_t.tag_ = Tag::AS_STRING; + return; + } + if (nlohmann_json_j.contains("as_strings")) { + nlohmann_json_t.variant_.emplace<9>(nlohmann_json_j.at("as_strings").template get>()); + nlohmann_json_t.tag_ = Tag::AS_STRINGS; + return; + } + if (nlohmann_json_j.contains("as_sym_int")) { + nlohmann_json_t.variant_.emplace<10>(nlohmann_json_j.at("as_sym_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_SYM_INT; + return; + } + if (nlohmann_json_j.contains("as_sym_ints")) { + nlohmann_json_t.variant_.emplace<11>(nlohmann_json_j.at("as_sym_ints").template get>()); + nlohmann_json_t.tag_ = Tag::AS_SYM_INTS; + return; + } + if (nlohmann_json_j.contains("as_scalar_type")) { + nlohmann_json_t.variant_.emplace<12>(nlohmann_json_j.at("as_scalar_type").template get()); + nlohmann_json_t.tag_ = Tag::AS_SCALAR_TYPE; + return; + } + if (nlohmann_json_j.contains("as_memory_format")) { + nlohmann_json_t.variant_.emplace<13>(nlohmann_json_j.at("as_memory_format").template get()); + nlohmann_json_t.tag_ = Tag::AS_MEMORY_FORMAT; + return; + } + if (nlohmann_json_j.contains("as_layout")) { + nlohmann_json_t.variant_.emplace<14>(nlohmann_json_j.at("as_layout").template get()); + nlohmann_json_t.tag_ = Tag::AS_LAYOUT; + return; + } + if (nlohmann_json_j.contains("as_device")) { + nlohmann_json_t.variant_.emplace<15>(nlohmann_json_j.at("as_device").template get()); + nlohmann_json_t.tag_ = Tag::AS_DEVICE; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<16>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + if (nlohmann_json_j.contains("as_bools")) { + nlohmann_json_t.variant_.emplace<17>(nlohmann_json_j.at("as_bools").template get>()); + nlohmann_json_t.tag_ = Tag::AS_BOOLS; + return; + } + if (nlohmann_json_j.contains("as_sym_bool")) { + nlohmann_json_t.variant_.emplace<18>(nlohmann_json_j.at("as_sym_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_SYM_BOOL; + return; + } + if (nlohmann_json_j.contains("as_sym_bools")) { + nlohmann_json_t.variant_.emplace<19>(nlohmann_json_j.at("as_sym_bools").template get>()); + nlohmann_json_t.tag_ = Tag::AS_SYM_BOOLS; + return; + } + if (nlohmann_json_j.contains("as_graph")) { + nlohmann_json_t.variant_.emplace<20>(nlohmann_json_j.at("as_graph").template get()); + nlohmann_json_t.tag_ = Tag::AS_GRAPH; + return; + } + if (nlohmann_json_j.contains("as_optional_tensors")) { + nlohmann_json_t.variant_.emplace<21>(nlohmann_json_j.at("as_optional_tensors").template get>()); + nlohmann_json_t.tag_ = Tag::AS_OPTIONAL_TENSORS; + return; + } + if (nlohmann_json_j.contains("as_custom_obj")) { + nlohmann_json_t.variant_.emplace<22>(nlohmann_json_j.at("as_custom_obj").template get()); + nlohmann_json_t.tag_ = Tag::AS_CUSTOM_OBJ; + return; + } + if (nlohmann_json_j.contains("as_operator")) { + nlohmann_json_t.variant_.emplace<23>(nlohmann_json_j.at("as_operator").template get()); + nlohmann_json_t.tag_ = Tag::AS_OPERATOR; + return; + } + } +}; + +class NamedArgument { + private: + std::string name; + Argument arg; + + public: + + const std::string& get_name() const { + return name; + } + + const Argument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const NamedArgument& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, NamedArgument& nlohmann_json_t); +}; + +class Node { + private: + std::string target; + std::vector inputs; + std::vector outputs; + std::unordered_map metadata; + + public: + + const std::string& get_target() const { + return target; + } + + const std::vector& get_inputs() const { + return inputs; + } + + const std::vector& get_outputs() const { + return outputs; + } + + const std::unordered_map& get_metadata() const { + return metadata; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Node& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Node& nlohmann_json_t); +}; + +class Graph { + private: + std::vector inputs; + std::vector outputs; + std::vector nodes; + std::unordered_map tensor_values; + std::unordered_map sym_int_values; + std::unordered_map sym_bool_values; + bool is_single_tensor_return = false; + std::unordered_map custom_obj_values = {}; + + public: + + const std::vector& get_inputs() const { + return inputs; + } + + const std::vector& get_outputs() const { + return outputs; + } + + const std::vector& get_nodes() const { + return nodes; + } + + const std::unordered_map& get_tensor_values() const { + return tensor_values; + } + + const std::unordered_map& get_sym_int_values() const { + return sym_int_values; + } + + const std::unordered_map& get_sym_bool_values() const { + return sym_bool_values; + } + + const bool& get_is_single_tensor_return() const { + return is_single_tensor_return; + } + + const std::unordered_map& get_custom_obj_values() const { + return custom_obj_values; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Graph& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_json_t); +}; + +class UserInputSpec { + private: + Argument arg; + + public: + + const Argument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const UserInputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, UserInputSpec& nlohmann_json_t); +}; + +class ConstantValue { + struct Void {}; + + public: + enum class Tag { + AS_NONE, AS_INT, AS_FLOAT, AS_STRING, AS_BOOL + }; + + private: + std::variant, int64_t, double, std::string, bool> variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const std::tuple<>& get_as_none() const { + return std::get<1>(variant_); + } + + const int64_t& get_as_int() const { + return std::get<2>(variant_); + } + + const double& get_as_float() const { + return std::get<3>(variant_); + } + + const std::string& get_as_string() const { + return std::get<4>(variant_); + } + + const bool& get_as_bool() const { + return std::get<5>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ConstantValue& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::AS_NONE) { + nlohmann_json_j["as_none"] = nlohmann_json_t.get_as_none(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_INT) { + nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_STRING) { + nlohmann_json_j["as_string"] = nlohmann_json_t.get_as_string(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { + nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, ConstantValue& nlohmann_json_t) { + + if (nlohmann_json_j.contains("as_none")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_none").template get>()); + nlohmann_json_t.tag_ = Tag::AS_NONE; + return; + } + if (nlohmann_json_j.contains("as_int")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_int").template get()); + nlohmann_json_t.tag_ = Tag::AS_INT; + return; + } + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; + return; + } + if (nlohmann_json_j.contains("as_string")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("as_string").template get()); + nlohmann_json_t.tag_ = Tag::AS_STRING; + return; + } + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; + return; + } + } +}; + +class ConstantInputSpec { + private: + std::string name; + ConstantValue value; + + public: + + const std::string& get_name() const { + return name; + } + + const ConstantValue& get_value() const { + return value; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ConstantInputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ConstantInputSpec& nlohmann_json_t); +}; + +class InputToParameterSpec { + private: + TensorArgument arg; + std::string parameter_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_parameter_name() const { + return parameter_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToParameterSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToParameterSpec& nlohmann_json_t); +}; + +class InputToBufferSpec { + private: + TensorArgument arg; + std::string buffer_name; + bool persistent; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_buffer_name() const { + return buffer_name; + } + + const bool& get_persistent() const { + return persistent; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToBufferSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToBufferSpec& nlohmann_json_t); +}; + +class InputToTensorConstantSpec { + private: + TensorArgument arg; + std::string tensor_constant_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_tensor_constant_name() const { + return tensor_constant_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToTensorConstantSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToTensorConstantSpec& nlohmann_json_t); +}; + +class InputToCustomObjSpec { + private: + CustomObjArgument arg; + std::string custom_obj_name; + + public: + + const CustomObjArgument& get_arg() const { + return arg; + } + + const std::string& get_custom_obj_name() const { + return custom_obj_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputToCustomObjSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputToCustomObjSpec& nlohmann_json_t); +}; + +class InputTokenSpec { + private: + TokenArgument arg; + + public: + + const TokenArgument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputTokenSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, InputTokenSpec& nlohmann_json_t); +}; + +class InputSpec { + struct Void {}; + + public: + enum class Tag { + USER_INPUT, PARAMETER, BUFFER, TENSOR_CONSTANT, CUSTOM_OBJ, TOKEN, CONSTANT_INPUT + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const UserInputSpec& get_user_input() const { + return std::get<1>(variant_); + } + + const InputToParameterSpec& get_parameter() const { + return std::get<2>(variant_); + } + + const InputToBufferSpec& get_buffer() const { + return std::get<3>(variant_); + } + + const InputToTensorConstantSpec& get_tensor_constant() const { + return std::get<4>(variant_); + } + + const InputToCustomObjSpec& get_custom_obj() const { + return std::get<5>(variant_); + } + + const InputTokenSpec& get_token() const { + return std::get<6>(variant_); + } + + const ConstantInputSpec& get_constant_input() const { + return std::get<7>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const InputSpec& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::USER_INPUT) { + nlohmann_json_j["user_input"] = nlohmann_json_t.get_user_input(); + return; + } + if (nlohmann_json_t.tag_ == Tag::PARAMETER) { + nlohmann_json_j["parameter"] = nlohmann_json_t.get_parameter(); + return; + } + if (nlohmann_json_t.tag_ == Tag::BUFFER) { + nlohmann_json_j["buffer"] = nlohmann_json_t.get_buffer(); + return; + } + if (nlohmann_json_t.tag_ == Tag::TENSOR_CONSTANT) { + nlohmann_json_j["tensor_constant"] = nlohmann_json_t.get_tensor_constant(); + return; + } + if (nlohmann_json_t.tag_ == Tag::CUSTOM_OBJ) { + nlohmann_json_j["custom_obj"] = nlohmann_json_t.get_custom_obj(); + return; + } + if (nlohmann_json_t.tag_ == Tag::TOKEN) { + nlohmann_json_j["token"] = nlohmann_json_t.get_token(); + return; + } + if (nlohmann_json_t.tag_ == Tag::CONSTANT_INPUT) { + nlohmann_json_j["constant_input"] = nlohmann_json_t.get_constant_input(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, InputSpec& nlohmann_json_t) { + + if (nlohmann_json_j.contains("user_input")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("user_input").template get()); + nlohmann_json_t.tag_ = Tag::USER_INPUT; + return; + } + if (nlohmann_json_j.contains("parameter")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("parameter").template get()); + nlohmann_json_t.tag_ = Tag::PARAMETER; + return; + } + if (nlohmann_json_j.contains("buffer")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("buffer").template get()); + nlohmann_json_t.tag_ = Tag::BUFFER; + return; + } + if (nlohmann_json_j.contains("tensor_constant")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("tensor_constant").template get()); + nlohmann_json_t.tag_ = Tag::TENSOR_CONSTANT; + return; + } + if (nlohmann_json_j.contains("custom_obj")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("custom_obj").template get()); + nlohmann_json_t.tag_ = Tag::CUSTOM_OBJ; + return; + } + if (nlohmann_json_j.contains("token")) { + nlohmann_json_t.variant_.emplace<6>(nlohmann_json_j.at("token").template get()); + nlohmann_json_t.tag_ = Tag::TOKEN; + return; + } + if (nlohmann_json_j.contains("constant_input")) { + nlohmann_json_t.variant_.emplace<7>(nlohmann_json_j.at("constant_input").template get()); + nlohmann_json_t.tag_ = Tag::CONSTANT_INPUT; + return; + } + } +}; + +class UserOutputSpec { + private: + Argument arg; + + public: + + const Argument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const UserOutputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlohmann_json_t); +}; + +class LossOutputSpec { + private: + TensorArgument arg; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const LossOutputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, LossOutputSpec& nlohmann_json_t); +}; + +class BufferMutationSpec { + private: + TensorArgument arg; + std::string buffer_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_buffer_name() const { + return buffer_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const BufferMutationSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, BufferMutationSpec& nlohmann_json_t); +}; + +class GradientToParameterSpec { + private: + TensorArgument arg; + std::string parameter_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_parameter_name() const { + return parameter_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GradientToParameterSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GradientToParameterSpec& nlohmann_json_t); +}; + +class GradientToUserInputSpec { + private: + TensorArgument arg; + std::string user_input_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_user_input_name() const { + return user_input_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GradientToUserInputSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GradientToUserInputSpec& nlohmann_json_t); +}; + +class UserInputMutationSpec { + private: + TensorArgument arg; + std::string user_input_name; + + public: + + const TensorArgument& get_arg() const { + return arg; + } + + const std::string& get_user_input_name() const { + return user_input_name; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const UserInputMutationSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, UserInputMutationSpec& nlohmann_json_t); +}; + +class OutputTokenSpec { + private: + TokenArgument arg; + + public: + + const TokenArgument& get_arg() const { + return arg; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const OutputTokenSpec& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, OutputTokenSpec& nlohmann_json_t); +}; + +class OutputSpec { + struct Void {}; + + public: + enum class Tag { + USER_OUTPUT, LOSS_OUTPUT, BUFFER_MUTATION, GRADIENT_TO_PARAMETER, GRADIENT_TO_USER_INPUT, USER_INPUT_MUTATION, TOKEN + }; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const { + return tag_; + } + + const UserOutputSpec& get_user_output() const { + return std::get<1>(variant_); + } + + const LossOutputSpec& get_loss_output() const { + return std::get<2>(variant_); + } + + const BufferMutationSpec& get_buffer_mutation() const { + return std::get<3>(variant_); + } + + const GradientToParameterSpec& get_gradient_to_parameter() const { + return std::get<4>(variant_); + } + + const GradientToUserInputSpec& get_gradient_to_user_input() const { + return std::get<5>(variant_); + } + + const UserInputMutationSpec& get_user_input_mutation() const { + return std::get<6>(variant_); + } + + const OutputTokenSpec& get_token() const { + return std::get<7>(variant_); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const OutputSpec& nlohmann_json_t) { + + if (nlohmann_json_t.tag_ == Tag::USER_OUTPUT) { + nlohmann_json_j["user_output"] = nlohmann_json_t.get_user_output(); + return; + } + if (nlohmann_json_t.tag_ == Tag::LOSS_OUTPUT) { + nlohmann_json_j["loss_output"] = nlohmann_json_t.get_loss_output(); + return; + } + if (nlohmann_json_t.tag_ == Tag::BUFFER_MUTATION) { + nlohmann_json_j["buffer_mutation"] = nlohmann_json_t.get_buffer_mutation(); + return; + } + if (nlohmann_json_t.tag_ == Tag::GRADIENT_TO_PARAMETER) { + nlohmann_json_j["gradient_to_parameter"] = nlohmann_json_t.get_gradient_to_parameter(); + return; + } + if (nlohmann_json_t.tag_ == Tag::GRADIENT_TO_USER_INPUT) { + nlohmann_json_j["gradient_to_user_input"] = nlohmann_json_t.get_gradient_to_user_input(); + return; + } + if (nlohmann_json_t.tag_ == Tag::USER_INPUT_MUTATION) { + nlohmann_json_j["user_input_mutation"] = nlohmann_json_t.get_user_input_mutation(); + return; + } + if (nlohmann_json_t.tag_ == Tag::TOKEN) { + nlohmann_json_j["token"] = nlohmann_json_t.get_token(); + return; + } + } + + friend void from_json(const nlohmann::json& nlohmann_json_j, OutputSpec& nlohmann_json_t) { + + if (nlohmann_json_j.contains("user_output")) { + nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("user_output").template get()); + nlohmann_json_t.tag_ = Tag::USER_OUTPUT; + return; + } + if (nlohmann_json_j.contains("loss_output")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("loss_output").template get()); + nlohmann_json_t.tag_ = Tag::LOSS_OUTPUT; + return; + } + if (nlohmann_json_j.contains("buffer_mutation")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("buffer_mutation").template get()); + nlohmann_json_t.tag_ = Tag::BUFFER_MUTATION; + return; + } + if (nlohmann_json_j.contains("gradient_to_parameter")) { + nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("gradient_to_parameter").template get()); + nlohmann_json_t.tag_ = Tag::GRADIENT_TO_PARAMETER; + return; + } + if (nlohmann_json_j.contains("gradient_to_user_input")) { + nlohmann_json_t.variant_.emplace<5>(nlohmann_json_j.at("gradient_to_user_input").template get()); + nlohmann_json_t.tag_ = Tag::GRADIENT_TO_USER_INPUT; + return; + } + if (nlohmann_json_j.contains("user_input_mutation")) { + nlohmann_json_t.variant_.emplace<6>(nlohmann_json_j.at("user_input_mutation").template get()); + nlohmann_json_t.tag_ = Tag::USER_INPUT_MUTATION; + return; + } + if (nlohmann_json_j.contains("token")) { + nlohmann_json_t.variant_.emplace<7>(nlohmann_json_j.at("token").template get()); + nlohmann_json_t.tag_ = Tag::TOKEN; + return; + } + } +}; + +class GraphSignature { + private: + std::vector input_specs; + std::vector output_specs; + + public: + + const std::vector& get_input_specs() const { + return input_specs; + } + + const std::vector& get_output_specs() const { + return output_specs; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GraphSignature& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GraphSignature& nlohmann_json_t); +}; + +class RangeConstraint { + private: + std::optional min_val; + std::optional max_val; + + public: + + const std::optional& get_min_val() const { + return min_val; + } + + const std::optional& get_max_val() const { + return max_val; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const RangeConstraint& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, RangeConstraint& nlohmann_json_t); +}; + +class ModuleCallSignature { + private: + std::vector inputs; + std::vector outputs; + std::string in_spec; + std::string out_spec; + std::optional> forward_arg_names = std::nullopt; + + public: + + const std::vector& get_inputs() const { + return inputs; + } + + const std::vector& get_outputs() const { + return outputs; + } + + const std::string& get_in_spec() const { + return in_spec; + } + + const std::string& get_out_spec() const { + return out_spec; + } + + const std::optional>& get_forward_arg_names() const { + return forward_arg_names; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallSignature& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallSignature& nlohmann_json_t); +}; + +class ModuleCallEntry { + private: + std::string fqn; + std::optional signature = std::nullopt; + + public: + + const std::string& get_fqn() const { + return fqn; + } + + const std::optional& get_signature() const { + return signature; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallEntry& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallEntry& nlohmann_json_t); +}; + +class GraphModule { + private: + Graph graph; + GraphSignature signature; + std::vector module_call_graph; + std::unordered_map metadata = {}; + + public: + + const Graph& get_graph() const { + return graph; + } + + const GraphSignature& get_signature() const { + return signature; + } + + const std::vector& get_module_call_graph() const { + return module_call_graph; + } + + const std::unordered_map& get_metadata() const { + return metadata; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const GraphModule& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, GraphModule& nlohmann_json_t); +}; + +class SchemaVersion { + private: + int64_t major; + int64_t minor; + + public: + + const int64_t& get_major() const { + return major; + } + + const int64_t& get_minor() const { + return minor; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const SchemaVersion& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, SchemaVersion& nlohmann_json_t); +}; + +class ExportedProgram { + private: + GraphModule graph_module; + std::unordered_map opset_version; + std::unordered_map range_constraints; + SchemaVersion schema_version; + std::vector verifiers = {}; + std::string torch_version = "<=2.4"; + + public: + + const GraphModule& get_graph_module() const { + return graph_module; + } + + const std::unordered_map& get_opset_version() const { + return opset_version; + } + + const std::unordered_map& get_range_constraints() const { + return range_constraints; + } + + const SchemaVersion& get_schema_version() const { + return schema_version; + } + + const std::vector& get_verifiers() const { + return verifiers; + } + + const std::string& get_torch_version() const { + return torch_version; + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t); +}; + +inline void to_json(nlohmann::json& nlohmann_json_j, const BufferMutationSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["buffer_name"] = nlohmann_json_t.buffer_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, BufferMutationSpec& nlohmann_json_t) { + BufferMutationSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.buffer_name = nlohmann_json_j.value("buffer_name", nlohmann_json_default_obj.buffer_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ConstantInputSpec& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["value"] = nlohmann_json_t.value; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ConstantInputSpec& nlohmann_json_t) { + ConstantInputSpec nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.value = nlohmann_json_j.value("value", nlohmann_json_default_obj.value); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const CustomObjArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["class_fqn"] = nlohmann_json_t.class_fqn; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, CustomObjArgument& nlohmann_json_t) { + CustomObjArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.class_fqn = nlohmann_json_j.value("class_fqn", nlohmann_json_default_obj.class_fqn); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const Device& nlohmann_json_t) { + nlohmann_json_j["type"] = nlohmann_json_t.type; + nlohmann_json_j["index"] = nlohmann_json_t.index; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Device& nlohmann_json_t) { + Device nlohmann_json_default_obj; + nlohmann_json_t.type = nlohmann_json_j.value("type", nlohmann_json_default_obj.type); + nlohmann_json_t.index = nlohmann_json_j.value("index", nlohmann_json_default_obj.index); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t) { + nlohmann_json_j["graph_module"] = nlohmann_json_t.graph_module; + nlohmann_json_j["opset_version"] = nlohmann_json_t.opset_version; + nlohmann_json_j["range_constraints"] = nlohmann_json_t.range_constraints; + nlohmann_json_j["schema_version"] = nlohmann_json_t.schema_version; + nlohmann_json_j["verifiers"] = nlohmann_json_t.verifiers; + nlohmann_json_j["torch_version"] = nlohmann_json_t.torch_version; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t) { + ExportedProgram nlohmann_json_default_obj; + nlohmann_json_t.graph_module = nlohmann_json_j.value("graph_module", nlohmann_json_default_obj.graph_module); + nlohmann_json_t.opset_version = nlohmann_json_j.value("opset_version", nlohmann_json_default_obj.opset_version); + nlohmann_json_t.range_constraints = nlohmann_json_j.value("range_constraints", nlohmann_json_default_obj.range_constraints); + nlohmann_json_t.schema_version = nlohmann_json_j.value("schema_version", nlohmann_json_default_obj.schema_version); + nlohmann_json_t.verifiers = nlohmann_json_j.value("verifiers", nlohmann_json_default_obj.verifiers); + nlohmann_json_t.torch_version = nlohmann_json_j.value("torch_version", nlohmann_json_default_obj.torch_version); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GradientToParameterSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["parameter_name"] = nlohmann_json_t.parameter_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GradientToParameterSpec& nlohmann_json_t) { + GradientToParameterSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.parameter_name = nlohmann_json_j.value("parameter_name", nlohmann_json_default_obj.parameter_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GradientToUserInputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["user_input_name"] = nlohmann_json_t.user_input_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GradientToUserInputSpec& nlohmann_json_t) { + GradientToUserInputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.user_input_name = nlohmann_json_j.value("user_input_name", nlohmann_json_default_obj.user_input_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const Graph& nlohmann_json_t) { + nlohmann_json_j["inputs"] = nlohmann_json_t.inputs; + nlohmann_json_j["outputs"] = nlohmann_json_t.outputs; + nlohmann_json_j["nodes"] = nlohmann_json_t.nodes; + nlohmann_json_j["tensor_values"] = nlohmann_json_t.tensor_values; + nlohmann_json_j["sym_int_values"] = nlohmann_json_t.sym_int_values; + nlohmann_json_j["sym_bool_values"] = nlohmann_json_t.sym_bool_values; + nlohmann_json_j["is_single_tensor_return"] = nlohmann_json_t.is_single_tensor_return; + nlohmann_json_j["custom_obj_values"] = nlohmann_json_t.custom_obj_values; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Graph& nlohmann_json_t) { + Graph nlohmann_json_default_obj; + nlohmann_json_t.inputs = nlohmann_json_j.value("inputs", nlohmann_json_default_obj.inputs); + nlohmann_json_t.outputs = nlohmann_json_j.value("outputs", nlohmann_json_default_obj.outputs); + nlohmann_json_t.nodes = nlohmann_json_j.value("nodes", nlohmann_json_default_obj.nodes); + nlohmann_json_t.tensor_values = nlohmann_json_j.value("tensor_values", nlohmann_json_default_obj.tensor_values); + nlohmann_json_t.sym_int_values = nlohmann_json_j.value("sym_int_values", nlohmann_json_default_obj.sym_int_values); + nlohmann_json_t.sym_bool_values = nlohmann_json_j.value("sym_bool_values", nlohmann_json_default_obj.sym_bool_values); + nlohmann_json_t.is_single_tensor_return = nlohmann_json_j.value("is_single_tensor_return", nlohmann_json_default_obj.is_single_tensor_return); + nlohmann_json_t.custom_obj_values = nlohmann_json_j.value("custom_obj_values", nlohmann_json_default_obj.custom_obj_values); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GraphArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["graph"] = nlohmann_json_t.graph; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GraphArgument& nlohmann_json_t) { + GraphArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.graph = nlohmann_json_j.value("graph", nlohmann_json_default_obj.graph); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GraphModule& nlohmann_json_t) { + nlohmann_json_j["graph"] = nlohmann_json_t.graph; + nlohmann_json_j["signature"] = nlohmann_json_t.signature; + nlohmann_json_j["module_call_graph"] = nlohmann_json_t.module_call_graph; + nlohmann_json_j["metadata"] = nlohmann_json_t.metadata; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GraphModule& nlohmann_json_t) { + GraphModule nlohmann_json_default_obj; + nlohmann_json_t.graph = nlohmann_json_j.value("graph", nlohmann_json_default_obj.graph); + nlohmann_json_t.signature = nlohmann_json_j.value("signature", nlohmann_json_default_obj.signature); + nlohmann_json_t.module_call_graph = nlohmann_json_j.value("module_call_graph", nlohmann_json_default_obj.module_call_graph); + nlohmann_json_t.metadata = nlohmann_json_j.value("metadata", nlohmann_json_default_obj.metadata); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const GraphSignature& nlohmann_json_t) { + nlohmann_json_j["input_specs"] = nlohmann_json_t.input_specs; + nlohmann_json_j["output_specs"] = nlohmann_json_t.output_specs; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, GraphSignature& nlohmann_json_t) { + GraphSignature nlohmann_json_default_obj; + nlohmann_json_t.input_specs = nlohmann_json_j.value("input_specs", nlohmann_json_default_obj.input_specs); + nlohmann_json_t.output_specs = nlohmann_json_j.value("output_specs", nlohmann_json_default_obj.output_specs); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToBufferSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["buffer_name"] = nlohmann_json_t.buffer_name; + nlohmann_json_j["persistent"] = nlohmann_json_t.persistent; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToBufferSpec& nlohmann_json_t) { + InputToBufferSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.buffer_name = nlohmann_json_j.value("buffer_name", nlohmann_json_default_obj.buffer_name); + nlohmann_json_t.persistent = nlohmann_json_j.value("persistent", nlohmann_json_default_obj.persistent); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToCustomObjSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["custom_obj_name"] = nlohmann_json_t.custom_obj_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToCustomObjSpec& nlohmann_json_t) { + InputToCustomObjSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.custom_obj_name = nlohmann_json_j.value("custom_obj_name", nlohmann_json_default_obj.custom_obj_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToParameterSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["parameter_name"] = nlohmann_json_t.parameter_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToParameterSpec& nlohmann_json_t) { + InputToParameterSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.parameter_name = nlohmann_json_j.value("parameter_name", nlohmann_json_default_obj.parameter_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputToTensorConstantSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["tensor_constant_name"] = nlohmann_json_t.tensor_constant_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputToTensorConstantSpec& nlohmann_json_t) { + InputToTensorConstantSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.tensor_constant_name = nlohmann_json_j.value("tensor_constant_name", nlohmann_json_default_obj.tensor_constant_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const InputTokenSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, InputTokenSpec& nlohmann_json_t) { + InputTokenSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const LossOutputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, LossOutputSpec& nlohmann_json_t) { + LossOutputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallEntry& nlohmann_json_t) { + nlohmann_json_j["fqn"] = nlohmann_json_t.fqn; + nlohmann_json_j["signature"] = nlohmann_json_t.signature; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallEntry& nlohmann_json_t) { + ModuleCallEntry nlohmann_json_default_obj; + nlohmann_json_t.fqn = nlohmann_json_j.value("fqn", nlohmann_json_default_obj.fqn); + nlohmann_json_t.signature = nlohmann_json_j.value("signature", nlohmann_json_default_obj.signature); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallSignature& nlohmann_json_t) { + nlohmann_json_j["inputs"] = nlohmann_json_t.inputs; + nlohmann_json_j["outputs"] = nlohmann_json_t.outputs; + nlohmann_json_j["in_spec"] = nlohmann_json_t.in_spec; + nlohmann_json_j["out_spec"] = nlohmann_json_t.out_spec; + nlohmann_json_j["forward_arg_names"] = nlohmann_json_t.forward_arg_names; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, ModuleCallSignature& nlohmann_json_t) { + ModuleCallSignature nlohmann_json_default_obj; + nlohmann_json_t.inputs = nlohmann_json_j.value("inputs", nlohmann_json_default_obj.inputs); + nlohmann_json_t.outputs = nlohmann_json_j.value("outputs", nlohmann_json_default_obj.outputs); + nlohmann_json_t.in_spec = nlohmann_json_j.value("in_spec", nlohmann_json_default_obj.in_spec); + nlohmann_json_t.out_spec = nlohmann_json_j.value("out_spec", nlohmann_json_default_obj.out_spec); + nlohmann_json_t.forward_arg_names = nlohmann_json_j.value("forward_arg_names", nlohmann_json_default_obj.forward_arg_names); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const NamedArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, NamedArgument& nlohmann_json_t) { + NamedArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const Node& nlohmann_json_t) { + nlohmann_json_j["target"] = nlohmann_json_t.target; + nlohmann_json_j["inputs"] = nlohmann_json_t.inputs; + nlohmann_json_j["outputs"] = nlohmann_json_t.outputs; + nlohmann_json_j["metadata"] = nlohmann_json_t.metadata; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Node& nlohmann_json_t) { + Node nlohmann_json_default_obj; + nlohmann_json_t.target = nlohmann_json_j.value("target", nlohmann_json_default_obj.target); + nlohmann_json_t.inputs = nlohmann_json_j.value("inputs", nlohmann_json_default_obj.inputs); + nlohmann_json_t.outputs = nlohmann_json_j.value("outputs", nlohmann_json_default_obj.outputs); + nlohmann_json_t.metadata = nlohmann_json_j.value("metadata", nlohmann_json_default_obj.metadata); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const OutputTokenSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, OutputTokenSpec& nlohmann_json_t) { + OutputTokenSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const RangeConstraint& nlohmann_json_t) { + nlohmann_json_j["min_val"] = nlohmann_json_t.min_val; + nlohmann_json_j["max_val"] = nlohmann_json_t.max_val; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, RangeConstraint& nlohmann_json_t) { + RangeConstraint nlohmann_json_default_obj; + nlohmann_json_t.min_val = nlohmann_json_j.value("min_val", nlohmann_json_default_obj.min_val); + nlohmann_json_t.max_val = nlohmann_json_j.value("max_val", nlohmann_json_default_obj.max_val); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const SchemaVersion& nlohmann_json_t) { + nlohmann_json_j["major"] = nlohmann_json_t.major; + nlohmann_json_j["minor"] = nlohmann_json_t.minor; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, SchemaVersion& nlohmann_json_t) { + SchemaVersion nlohmann_json_default_obj; + nlohmann_json_t.major = nlohmann_json_j.value("major", nlohmann_json_default_obj.major); + nlohmann_json_t.minor = nlohmann_json_j.value("minor", nlohmann_json_default_obj.minor); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const SymExpr& nlohmann_json_t) { + nlohmann_json_j["expr_str"] = nlohmann_json_t.expr_str; + nlohmann_json_j["hint"] = nlohmann_json_t.hint; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, SymExpr& nlohmann_json_t) { + SymExpr nlohmann_json_default_obj; + nlohmann_json_t.expr_str = nlohmann_json_j.value("expr_str", nlohmann_json_default_obj.expr_str); + nlohmann_json_t.hint = nlohmann_json_j.value("hint", nlohmann_json_default_obj.hint); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const TensorArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, TensorArgument& nlohmann_json_t) { + TensorArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const TensorMeta& nlohmann_json_t) { + nlohmann_json_j["dtype"] = nlohmann_json_t.dtype; + nlohmann_json_j["sizes"] = nlohmann_json_t.sizes; + nlohmann_json_j["requires_grad"] = nlohmann_json_t.requires_grad; + nlohmann_json_j["device"] = nlohmann_json_t.device; + nlohmann_json_j["strides"] = nlohmann_json_t.strides; + nlohmann_json_j["storage_offset"] = nlohmann_json_t.storage_offset; + nlohmann_json_j["layout"] = nlohmann_json_t.layout; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, TensorMeta& nlohmann_json_t) { + TensorMeta nlohmann_json_default_obj; + nlohmann_json_t.dtype = nlohmann_json_j.value("dtype", nlohmann_json_default_obj.dtype); + nlohmann_json_t.sizes = nlohmann_json_j.value("sizes", nlohmann_json_default_obj.sizes); + nlohmann_json_t.requires_grad = nlohmann_json_j.value("requires_grad", nlohmann_json_default_obj.requires_grad); + nlohmann_json_t.device = nlohmann_json_j.value("device", nlohmann_json_default_obj.device); + nlohmann_json_t.strides = nlohmann_json_j.value("strides", nlohmann_json_default_obj.strides); + nlohmann_json_t.storage_offset = nlohmann_json_j.value("storage_offset", nlohmann_json_default_obj.storage_offset); + nlohmann_json_t.layout = nlohmann_json_j.value("layout", nlohmann_json_default_obj.layout); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const TokenArgument& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, TokenArgument& nlohmann_json_t) { + TokenArgument nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const UserInputMutationSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; + nlohmann_json_j["user_input_name"] = nlohmann_json_t.user_input_name; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, UserInputMutationSpec& nlohmann_json_t) { + UserInputMutationSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); + nlohmann_json_t.user_input_name = nlohmann_json_j.value("user_input_name", nlohmann_json_default_obj.user_input_name); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const UserInputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, UserInputSpec& nlohmann_json_t) { + UserInputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +inline void to_json(nlohmann::json& nlohmann_json_j, const UserOutputSpec& nlohmann_json_t) { + nlohmann_json_j["arg"] = nlohmann_json_t.arg; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlohmann_json_t) { + UserOutputSpec nlohmann_json_default_obj; + nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); +} + +} // namespace _export +} // namespace torch + +// clang-format on