Files
pytorch/scripts/export/update_schema.py
Zhengxu Chen 3ef2dfc1ba [export] Implement cpp deserializer. (#136398)
Differential Revision: D63206258

This diff introduces a mechanism to generate a JSON-compatible deserializer in C++ using nlohmann/json (already used by AOTI).

Why do we need this? There are many cases where people don't want to use Python to load the graph (e.g. a C++ runtime); they can use this header to deserialize the JSON graph instead.

Every time update_schema.py is run to update the schema, the header is auto-generated and included in the source files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136398
Approved by: https://github.com/angelayi
2024-11-14 16:34:59 +00:00

83 lines
3.0 KiB
Python

import argparse
import os
from yaml import dump, Dumper
from torch._export.serde import schema_check
def _parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` with ``prefix``, ``dry_run`` and
        ``force_unsafe`` attributes.
    """
    parser = argparse.ArgumentParser(prog="update_schema")
    parser.add_argument(
        "--prefix", type=str, required=True, help="The root of pytorch directory."
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the schema instead of writing it to file.",
    )
    parser.add_argument(
        "--force-unsafe",
        action="store_true",
        help="!!! Only use this option when you are a chad. !!! Force to write the schema even if schema validation doesn't pass.",
    )
    args = parser.parse_args(argv)
    # Validate explicitly rather than with `assert`, which `python -O` strips.
    # parser.error() prints usage plus the message and exits with status 2.
    if not os.path.exists(args.prefix):
        parser.error(
            f"Assuming path {args.prefix} is the root of pytorch directory, but it doesn't exist."
        )
    return args


def _check_no_version_downgrade(commit):
    """Raise if the regenerated schema lowers a version relative to the base.

    Compares ``commit.result`` (the newly generated schema) against
    ``commit.base`` (presumably the schema currently checked in -- TODO
    confirm) for both SCHEMA_VERSION and TREESPEC_VERSION.

    Raises:
        RuntimeError: if either version in ``commit.result`` is lower than
            the corresponding version in ``commit.base``.
    """
    for key, label in (("SCHEMA_VERSION", "Schema"), ("TREESPEC_VERSION", "Treespec")):
        old_version, new_version = commit.base[key], commit.result[key]
        if new_version < old_version:
            raise RuntimeError(
                f"{label} version downgraded from {old_version} to {new_version}."
            )


def _render_outputs(commit):
    """Render the YAML schema file content and the C++ header for ``commit``.

    Both outputs start with the same "generated by" banner and checksum line,
    each written in the comment syntax of its own language.

    Returns:
        A ``(yaml_content, cpp_header)`` pair of strings.
    """
    # "@" is concatenated separately, presumably so tools that scan source
    # files for the "@generated" marker do not flag this script itself.
    banner = (
        "@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
    )
    checksum = f"checksum<<{commit.checksum_result}>>"

    yaml_content = "\n".join(
        [
            "# " + banner,
            "# " + checksum,
            # sort_keys=False keeps the key order of commit.result.
            dump(commit.result, Dumper=Dumper, sort_keys=False),
        ]
    )
    # clang-format is switched off around the generated body so the
    # formatter leaves the emitted code as generated.
    cpp_header = "\n".join(
        [
            "// " + banner,
            "// " + checksum,
            "// clang-format off",
            commit.cpp_header,
            "// clang-format on",
            "",  # yields the trailing newline
        ]
    )
    return yaml_content, cpp_header


def main(argv=None):
    """Regenerate the export schema YAML and C++ header under ``--prefix``.

    1. Regenerate the schema via ``schema_check.update_schema()``.
    2. If the schema YAML already exists, refuse any version downgrade.
       Otherwise, require ``--force-unsafe``.
    3. Run ``schema_check.check``. When it reports a required next version,
       SCHEMA_VERSION must already equal it.
    4. Either print the YAML (``--dry-run``) or write both the YAML and the
       C++ header.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        RuntimeError: on a version downgrade, a missing schema YAML without
            ``--force-unsafe``, or a SCHEMA_VERSION that was not bumped to
            the value ``schema_check.check`` expects.
    """
    args = _parse_args(argv)
    commit = schema_check.update_schema()

    # Paths are built by plain string concatenation (not os.path.join).
    # Presumably --prefix ends with a separator or the commit paths start
    # with one -- TODO confirm before changing.
    yaml_path = args.prefix + commit.yaml_path
    cpp_header_path = args.prefix + commit.cpp_header_path

    if os.path.exists(yaml_path):
        _check_no_version_downgrade(commit)
    elif not args.force_unsafe:
        # Explicit raise instead of `assert`, so the guard survives `python -O`.
        raise RuntimeError(
            "Existing schema yaml file not found, please use --force-unsafe to try again."
        )

    next_version, reason = schema_check.check(commit, args.force_unsafe)
    if next_version is not None and next_version != commit.result["SCHEMA_VERSION"]:
        raise RuntimeError(
            f"Schema version is not updated from {commit.base['SCHEMA_VERSION']} to {next_version}.\n"
            + "Please either:\n"
            + "  1. update schema.py to not break compatibility.\n"
            + "  or 2. bump the schema version to the expected value.\n"
            + "  or 3. use --force-unsafe to override schema.yaml (not recommended).\n "
            + "and try again.\n"
            + f"Reason: {reason}"
        )

    yaml_content, cpp_header = _render_outputs(commit)

    if args.dry_run:
        print(yaml_content)
        print("\nWill write the above schema to " + yaml_path)
        print("Will also write the generated C++ header to " + cpp_header_path)
        return

    with open(yaml_path, "w", encoding="utf-8") as f:
        f.write(yaml_content)
    with open(cpp_header_path, "w", encoding="utf-8") as f:
        f.write(cpp_header)


if __name__ == "__main__":
    main()