[export] Add test to enforce consistency between synced thrift and generated thrift from schema.py (#141989)

Summary:
In this diff we ensure that the internal thrift schema in cfgr (configerator/structs/caffe2/torch/export/schema.thrift) and the OSS schema (torch/_export/serde/schema.thrift) stay in sync, by adding a unit test that reflects on the type names and fields of each schema and compares them field by field.

When new fields/types are added to torch/_export/serde/schema.thrift, the test will fail on trunk, and the error message will direct people to add the missing field/type to the cfgr thrift schema, so the two schemas stay in sync in practice.
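For illustration, here is a minimal sketch of the field-by-field comparison idea. It is not the actual test body; the helper names (thrift_fields, python_fields, check_in_sync) and the regexes are assumptions made for this sketch only.

import dataclasses
import re

def thrift_fields(thrift_text):
    # Collect {type_name: {field_names}} from "struct"/"union" blocks of a thrift file.
    fields = {}
    for m in re.finditer(r"(?:struct|union)\s+(\w+)\s*\{(.*?)\}", thrift_text, re.S):
        name, body = m.group(1), m.group(2)
        fields[name] = set(re.findall(r"\d+\s*:\s*(?:optional\s+)?[\w<>,. ]+\s+(\w+)\s*;", body))
    return fields

def python_fields(schema_module):
    # Collect {class_name: {field_names}} from every dataclass defined in schema.py.
    return {
        name: {f.name for f in dataclasses.fields(cls)}
        for name, cls in vars(schema_module).items()
        if isinstance(cls, type) and dataclasses.is_dataclass(cls)
    }

def check_in_sync(thrift_text, schema_module):
    # Report any type or field present in schema.py but missing from the thrift schema.
    missing = []
    thrift = thrift_fields(thrift_text)
    for name, fields in python_fields(schema_module).items():
        if name not in thrift:
            missing.append(f"type {name} is missing from the thrift schema")
        else:
            missing += [f"field {name}.{f} is missing from the thrift schema"
                        for f in fields - thrift[name]]
    return missing

In the real test, a non-empty result like this fails the run on trunk with a message telling the author to mirror the change in the cfgr thrift schema.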

Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_thrift_schema_in_sync

Differential Revision: D66716834

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141989
Approved by: https://github.com/yiming0416
Zhengxu Chen authored on 2024-12-06 18:42:18 +00:00; committed by PyTorch MergeBot
parent bab15df40a
commit 1a7da6e7e9
8 changed files with 117 additions and 55 deletions


@@ -58,7 +58,7 @@ if __name__ == "__main__":
first_line = (
"@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
)
checksum = f"checksum<<{commit.checksum_result}>>"
checksum = f"checksum<<{commit.checksum_next}>>"
yaml_header = "# " + first_line
yaml_header += "\n# " + checksum
yaml_payload = dump(commit.result, Dumper=Dumper, sort_keys=False)
@@ -73,7 +73,7 @@ if __name__ == "__main__":
yaml_content = yaml_header + "\n" + yaml_payload
thrift_schema = "// " + first_line
thrift_schema += "\n// " + checksum
thrift_schema += f"\n// checksum<<{commit.thrift_checksum_next}>>"
thrift_schema += "\n" + commit.thrift_schema
if args.dry_run:


@@ -1338,6 +1338,7 @@ def main():
"_inductor/codegen/*.h",
"_inductor/codegen/aoti_runtime/*.cpp",
"_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake",
"share/cmake/Caffe2/*.cmake",
"share/cmake/Caffe2/public/*.cmake",


@@ -26,7 +26,27 @@ Example(s):
except SchemaUpdateError as e:
self.fail(f"Failed to update schema: {e}\n{msg}")
self.assertEqual(commit.checksum_base, commit.checksum_result, msg)
self.assertEqual(commit.checksum_head, commit.checksum_next, msg)
def test_thrift_schema_unchanged(self):
msg = """
Detected an unexpected change to schema.thrift. Please update schema.py instead and run the following script:
Example(s):
python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory>
"""
if IS_FBCODE:
msg += """or
buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/
"""
try:
commit = update_schema()
except SchemaUpdateError as e:
self.fail(f"Failed to update schema: {e}\n{msg}")
self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_real, msg)
self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_next, msg)
def test_schema_diff(self):
additions, subtractions = _diff_schema(
@@ -105,14 +125,17 @@ Example(s):
commit = _Commit(
result=src,
checksum_result="",
checksum_next="",
yaml_path="",
additions=additions,
subtractions=subtractions,
base=dst,
checksum_base="",
checksum_head="",
cpp_header="",
cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="",
thrift_schema_path="",
)
@@ -141,14 +164,17 @@ Example(s):
commit = _Commit(
result=src,
checksum_result="",
checksum_next="",
yaml_path="",
additions=additions,
subtractions=subtractions,
base=dst,
checksum_base="",
checksum_head="",
cpp_header="",
cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="",
thrift_schema_path="",
)
@@ -180,14 +206,17 @@ Example(s):
commit = _Commit(
result=src,
checksum_result="",
checksum_next="",
yaml_path="",
additions=additions,
subtractions=subtractions,
base=dst,
checksum_base="",
checksum_head="",
cpp_header="",
cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="",
thrift_schema_path="",
)
@@ -242,14 +271,17 @@ Example(s):
commit = _Commit(
result=src,
checksum_result="",
checksum_next="",
yaml_path="",
additions=additions,
subtractions=subtractions,
base=dst,
checksum_base="",
checksum_head="",
cpp_header="",
cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="",
thrift_schema_path="",
)
@@ -274,14 +306,17 @@ Example(s):
commit = _Commit(
result=src,
checksum_result="",
checksum_next="",
yaml_path="",
additions=additions,
subtractions=subtractions,
base=dst,
checksum_base="",
checksum_head="",
cpp_header="",
cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="",
thrift_schema_path="",
)
@@ -313,14 +348,17 @@ Example(s):
commit = _Commit(
result=src,
checksum_result="",
checksum_next="",
yaml_path="",
additions=additions,
subtractions=subtractions,
base=dst,
checksum_base="",
checksum_head="",
cpp_header="",
cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="",
thrift_schema_path="",
)
@@ -349,14 +387,17 @@ Example(s):
commit = _Commit(
result=src,
checksum_result="",
checksum_next="",
yaml_path="",
additions=additions,
subtractions=subtractions,
base=dst,
checksum_base="",
checksum_head="",
cpp_header="",
cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="",
thrift_schema_path="",
)


@@ -58,8 +58,8 @@ class Device:
@dataclass(repr=False)
class SymExprHint(_Union):
as_int: Annotated[int, 10]
as_float: Annotated[float, 20]
as_bool: Annotated[bool, 30]
as_bool: Annotated[bool, 20]
as_float: Annotated[float, 30]
# This is for storing the symbolic expressions behind symints/symfloats/symbools


@@ -1,7 +1,7 @@
// @generated by update_schema.py
// checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>>
// checksum<<0e89c5e620ad16c05bfe4fa2060ad43dcb0938dc31d77faad36b92f216c2c903>>
namespace py3 torch._export.schema
namespace py3 torch._export
namespace cpp2 torch._export.schema
enum Layout {
@@ -51,8 +51,8 @@ struct Device {
union SymExprHint {
10: i64 as_int;
20: double as_float;
30: bool as_bool;
20: bool as_bool;
30: double as_float;
}
struct SymExpr {


@@ -1,5 +1,5 @@
# @generated by update_schema.py
# checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>>
# checksum<<0335ca6e44a8a815ea638d538de0ad4f78a644af2689f6e93c0e8219117466e7>>
Argument:
kind: union
fields:
@@ -380,10 +380,10 @@ SymExprHint:
fields:
as_int:
type: int
as_float:
type: float
as_bool:
type: bool
as_float:
type: float
SymFloat:
kind: union
fields:


@@ -31,7 +31,7 @@ def _staged_schema():
thrift_type_defs: Dict[str, str] = {}
def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
def dump_type(t) -> Tuple[str, str, str]:
def dump_type(t, level: int) -> Tuple[str, str, str]:
CPP_TYPE_MAP = {
str: "std::string",
int: "int64_t",
@@ -90,20 +90,21 @@ def _staged_schema():
"",
)
elif o == Union:
assert level == 0, "Optional is only supported at the top level."
args = typing.get_args(t)
assert len(args) == 2 and args[1] == type(None)
yaml_type, cpp_type, thrift_type = dump_type(args[0])
yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1)
return (
f"Optional[{yaml_type}]",
f"std::optional<{cpp_type}>",
f"optional {thrift_type}",
)
elif o == Annotated:
return dump_type(t.__origin__)
return dump_type(t.__origin__, level)
else:
raise AssertionError(f"Type {t} is not supported in export schema.")
yaml_arg_types, cpp_arg_types, thrift_arg_types = zip(
*[dump_type(x) for x in typing.get_args(t)]
*[dump_type(x, level + 1) for x in typing.get_args(t)]
)
return (
(f"{yaml_head}[{', '.join(yaml_arg_types)}]"),
@@ -136,7 +137,7 @@ def _staged_schema():
)
def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str], str, int]:
t, cpp_type, thrift_type = dump_type(f.type)
t, cpp_type, thrift_type = dump_type(f.type, 0)
ret = {"type": t}
cpp_default: Optional[str] = None
assert (
@@ -455,7 +456,7 @@ void from_json(const nlohmann::json& j, ForwardRef<T>& p) {{
}} // namespace torch
"""
thrift_schema = f"""
namespace py3 torch._export.schema
namespace py3 torch._export
namespace cpp2 torch._export.schema
{chr(10).join(thrift_enum_defs)}
{chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())}
@@ -528,21 +529,24 @@ def _diff_schema(dst, src):
return additions, subtractions
def _hash_schema(s):
return hashlib.sha256(repr(s).encode("utf-8")).hexdigest()
def _hash_content(s: str):
return hashlib.sha256(s.strip().encode("utf-8")).hexdigest()
@dataclasses.dataclass
class _Commit:
result: Dict[str, Any]
checksum_result: str
checksum_next: str
yaml_path: str
additions: Dict[str, Any]
subtractions: Dict[str, Any]
base: Dict[str, Any]
checksum_base: Optional[str]
checksum_head: Optional[str]
cpp_header: str
cpp_header_path: str
thrift_checksum_head: Optional[str]
thrift_checksum_real: Optional[str]
thrift_checksum_next: str
thrift_schema: str
thrift_schema_path: str
@@ -555,13 +559,26 @@ def update_schema():
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
_check(match is not None, "checksum not found in schema.yaml")
assert match is not None
checksum_base = match.group(1)
checksum_head = match.group(1)
thrift_content = importlib.resources.read_text(__package__, "schema.thrift")
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content)
_check(match is not None, "checksum not found in schema.thrift")
assert match is not None
thrift_checksum_head = match.group(1)
thrift_content = thrift_content.splitlines()
assert thrift_content[0].startswith("// @" + "generated")
assert thrift_content[1].startswith("// checksum<<")
thrift_checksum_real = _hash_content("\n".join(thrift_content[2:]))
from yaml import load, Loader
dst = load(content, Loader=Loader)
assert isinstance(dst, dict)
else:
checksum_base = None
checksum_head = None
thrift_checksum_head = None
thrift_checksum_real = None
dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
src, cpp_header, thrift_schema = _staged_schema()
@@ -574,14 +591,17 @@ def update_schema():
return _Commit(
result=src,
checksum_result=_hash_schema(src),
checksum_next=_hash_content(repr(src)),
yaml_path=yaml_path,
additions=additions,
subtractions=subtractions,
base=dst,
checksum_base=checksum_base,
checksum_head=checksum_head,
cpp_header=cpp_header,
cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h",
thrift_checksum_head=thrift_checksum_head,
thrift_checksum_real=thrift_checksum_real,
thrift_checksum_next=_hash_content(thrift_schema),
thrift_schema=thrift_schema,
thrift_schema_path=thrift_schema_path,
)


@@ -1,5 +1,5 @@
// @generated by update_schema.py
// checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>>
// checksum<<0335ca6e44a8a815ea638d538de0ad4f78a644af2689f6e93c0e8219117466e7>>
// clang-format off
#pragma once
@@ -191,11 +191,11 @@ class SymExprHint {
public:
enum class Tag {
AS_INT, AS_FLOAT, AS_BOOL
AS_INT, AS_BOOL, AS_FLOAT
};
private:
std::variant<Void, int64_t, double, bool> variant_;
std::variant<Void, int64_t, bool, double> variant_;
Tag tag_;
public:
@@ -207,11 +207,11 @@ class SymExprHint {
return std::get<1>(variant_);
}
const double& get_as_float() const {
const bool& get_as_bool() const {
return std::get<2>(variant_);
}
const bool& get_as_bool() const {
const double& get_as_float() const {
return std::get<3>(variant_);
}
@@ -221,14 +221,14 @@ class SymExprHint {
nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int();
return;
}
if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) {
nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float();
return;
}
if (nlohmann_json_t.tag_ == Tag::AS_BOOL) {
nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool();
return;
}
if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) {
nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float();
return;
}
}
friend void from_json(const nlohmann::json& nlohmann_json_j, SymExprHint& nlohmann_json_t) {
@@ -238,14 +238,14 @@ class SymExprHint {
nlohmann_json_t.tag_ = Tag::AS_INT;
return;
}
if (nlohmann_json_j.contains("as_float")) {
nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_float").template get<double>());
nlohmann_json_t.tag_ = Tag::AS_FLOAT;
if (nlohmann_json_j.contains("as_bool")) {
nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_bool").template get<bool>());
nlohmann_json_t.tag_ = Tag::AS_BOOL;
return;
}
if (nlohmann_json_j.contains("as_bool")) {
nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_bool").template get<bool>());
nlohmann_json_t.tag_ = Tag::AS_BOOL;
if (nlohmann_json_j.contains("as_float")) {
nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_float").template get<double>());
nlohmann_json_t.tag_ = Tag::AS_FLOAT;
return;
}
}