mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Differential Revision: D63206258 This diff introduces a mechanism to generate a JSON-compatible deserializer in C++ using nlohmann json (already used by AOTI). Why do we need this? Because there will be many cases where people don't want to use Python to load the graph (e.g. a C++ runtime); instead, they can use this header to deserialize the JSON graph. Every time we call update_schema.py to update the schema, the header will be auto-generated and included in the source files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136398 Approved by: https://github.com/angelayi
83 lines
3.0 KiB
Python
83 lines
3.0 KiB
Python
import argparse
|
|
import os
|
|
|
|
from yaml import dump, Dumper
|
|
|
|
from torch._export.serde import schema_check
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(prog="update_schema")
    parser.add_argument(
        "--prefix", type=str, required=True, help="The root of pytorch directory."
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the schema instead of writing it to file.",
    )
    parser.add_argument(
        "--force-unsafe",
        action="store_true",
        help="!!! Only use this option when you are a chad. !!! Force to write the schema even if schema validation doesn't pass.",
    )
    return parser


def _under_prefix(prefix: str, rel_path: str) -> str:
    """Resolve ``rel_path`` (relative to the repo root) under ``prefix``.

    Plain string concatenation produces a wrong path whenever ``--prefix``
    lacks a trailing separator; ``os.path.join`` handles both forms. A leading
    "/" on ``rel_path`` is stripped so it is still treated as relative to
    ``prefix`` instead of overriding it as an absolute path.
    """
    return os.path.join(prefix, rel_path.lstrip("/"))


def _check_not_downgraded(commit, key: str, label: str) -> None:
    """Raise RuntimeError if ``commit.result[key]`` < ``commit.base[key]``.

    ``commit.base`` holds the values from the existing schema file and
    ``commit.result`` the freshly staged ones; ``label`` is the human-readable
    name used in the error message (e.g. "Schema", "Treespec").
    """
    old = commit.base[key]
    new = commit.result[key]
    if new < old:
        raise RuntimeError(f"{label} version downgraded from {old} to {new}.")


def main() -> None:
    """Regenerate the export schema YAML and its generated C++ header.

    1. Obtain the staged schema ``commit`` from ``schema_check.update_schema()``.
    2. If a schema YAML already exists under --prefix, refuse to lower
       SCHEMA_VERSION or TREESPEC_VERSION; if none exists, require
       --force-unsafe.
    3. Run ``schema_check.check`` and require SCHEMA_VERSION to equal the
       version it reports as expected (when it reports one).
    4. Render the YAML and C++ header (each prefixed with a generated-by line
       and checksum), then print them (--dry-run) or write them under --prefix.

    Raises:
        RuntimeError: on a version downgrade or an un-bumped schema version.
        SystemExit: via ``parser.error`` on invalid command-line usage.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Explicit errors rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(args.prefix):
        parser.error(
            f"Assuming path {args.prefix} is the root of pytorch directory, but it doesn't exist."
        )

    commit = schema_check.update_schema()
    yaml_path = _under_prefix(args.prefix, commit.yaml_path)
    cpp_header_path = _under_prefix(args.prefix, commit.cpp_header_path)

    if os.path.exists(yaml_path):
        _check_not_downgraded(commit, "SCHEMA_VERSION", "Schema")
        _check_not_downgraded(commit, "TREESPEC_VERSION", "Treespec")
    elif not args.force_unsafe:
        parser.error(
            "Existing schema yaml file not found, please use --force-unsafe to try again."
        )

    next_version, reason = schema_check.check(commit, args.force_unsafe)

    if next_version is not None and next_version != commit.result["SCHEMA_VERSION"]:
        raise RuntimeError(
            f"Schema version is not updated from {commit.base['SCHEMA_VERSION']} to {next_version}.\n"
            + "Please either:\n"
            + "  1. update schema.py to not break compatibility.\n"
            + "  or 2. bump the schema version to the expected value.\n"
            + "  or 3. use --force-unsafe to override schema.yaml (not recommended).\n "
            + "and try again.\n"
            + f"Reason: {reason}"
        )

    # "@" and "generated" are concatenated at runtime so the literal marker never
    # appears in this script's own source -- presumably so tooling that scans for
    # it does not treat this script itself as a generated file.
    first_line = (
        "@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
    )
    checksum = f"checksum<<{commit.checksum_result}>>"

    yaml_content = "\n".join(
        [
            "# " + first_line,
            "# " + checksum,
            dump(commit.result, Dumper=Dumper, sort_keys=False),
        ]
    )
    cpp_header = (
        "\n".join(
            [
                "// " + first_line,
                "// " + checksum,
                "// clang-format off",
                commit.cpp_header,
                "// clang-format on",
            ]
        )
        + "\n"
    )

    if args.dry_run:
        print(yaml_content)
        print("\nWill write the above schema to " + yaml_path)
        print("Will write the generated C++ header to " + cpp_header_path)
        return

    with open(yaml_path, "w", encoding="utf-8") as f:
        f.write(yaml_content)
    with open(cpp_header_path, "w", encoding="utf-8") as f:
        f.write(cpp_header)


if __name__ == "__main__":
    main()
|