mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[export] Add test to enforce consistency between synced thrift and generated thrift from schema.py (#141989)
Summary: This diff adds a way to ensure that the internal thrift schema from cfgr (configerator/structs/caffe2/torch/export/schema.thrift) and the schema in OSS (torch/_export/serde/schema.thrift) stay in sync. It adds a unit test that reflects on the type names and fields of each schema and compares them field by field. When new fields or types are detected in torch/_export/serde/schema.thrift, the test fails on trunk and the error message tells people to add the missing field or type to the thrift schema from cfgr, so that the two schemas are always in sync in practice.
Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_thrift_schema_in_sync
Differential Revision: D66716834
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141989
Approved by: https://github.com/yiming0416
This commit is contained in:
committed by
PyTorch MergeBot
parent
bab15df40a
commit
1a7da6e7e9
@ -58,7 +58,7 @@ if __name__ == "__main__":
|
||||
first_line = (
|
||||
"@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
|
||||
)
|
||||
checksum = f"checksum<<{commit.checksum_result}>>"
|
||||
checksum = f"checksum<<{commit.checksum_next}>>"
|
||||
yaml_header = "# " + first_line
|
||||
yaml_header += "\n# " + checksum
|
||||
yaml_payload = dump(commit.result, Dumper=Dumper, sort_keys=False)
|
||||
@ -73,7 +73,7 @@ if __name__ == "__main__":
|
||||
yaml_content = yaml_header + "\n" + yaml_payload
|
||||
|
||||
thrift_schema = "// " + first_line
|
||||
thrift_schema += "\n// " + checksum
|
||||
thrift_schema += f"\n// checksum<<{commit.thrift_checksum_next}>>"
|
||||
thrift_schema += "\n" + commit.thrift_schema
|
||||
|
||||
if args.dry_run:
|
||||
|
1
setup.py
1
setup.py
@ -1338,6 +1338,7 @@ def main():
|
||||
"_inductor/codegen/*.h",
|
||||
"_inductor/codegen/aoti_runtime/*.cpp",
|
||||
"_export/serde/*.yaml",
|
||||
"_export/serde/*.thrift",
|
||||
"share/cmake/ATen/*.cmake",
|
||||
"share/cmake/Caffe2/*.cmake",
|
||||
"share/cmake/Caffe2/public/*.cmake",
|
||||
|
@ -26,7 +26,27 @@ Example(s):
|
||||
except SchemaUpdateError as e:
|
||||
self.fail(f"Failed to update schema: {e}\n{msg}")
|
||||
|
||||
self.assertEqual(commit.checksum_base, commit.checksum_result, msg)
|
||||
self.assertEqual(commit.checksum_head, commit.checksum_next, msg)
|
||||
|
||||
def test_thrift_schema_unchanged(self):
|
||||
msg = """
|
||||
Detected an unexpected change to schema.thrift. Please update schema.py instead and run the following script:
|
||||
Example(s):
|
||||
python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory>
|
||||
"""
|
||||
|
||||
if IS_FBCODE:
|
||||
msg += """or
|
||||
buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/
|
||||
"""
|
||||
|
||||
try:
|
||||
commit = update_schema()
|
||||
except SchemaUpdateError as e:
|
||||
self.fail(f"Failed to update schema: {e}\n{msg}")
|
||||
|
||||
self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_real, msg)
|
||||
self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_next, msg)
|
||||
|
||||
def test_schema_diff(self):
|
||||
additions, subtractions = _diff_schema(
|
||||
@ -105,14 +125,17 @@ Example(s):
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
checksum_next="",
|
||||
yaml_path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
checksum_head="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_checksum_head="",
|
||||
thrift_checksum_real="",
|
||||
thrift_checksum_next="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
@ -141,14 +164,17 @@ Example(s):
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
checksum_next="",
|
||||
yaml_path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
checksum_head="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_checksum_head="",
|
||||
thrift_checksum_real="",
|
||||
thrift_checksum_next="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
@ -180,14 +206,17 @@ Example(s):
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
checksum_next="",
|
||||
yaml_path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
checksum_head="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_checksum_head="",
|
||||
thrift_checksum_real="",
|
||||
thrift_checksum_next="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
@ -242,14 +271,17 @@ Example(s):
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
checksum_next="",
|
||||
yaml_path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
checksum_head="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_checksum_head="",
|
||||
thrift_checksum_real="",
|
||||
thrift_checksum_next="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
@ -274,14 +306,17 @@ Example(s):
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
checksum_next="",
|
||||
yaml_path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
checksum_head="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_checksum_head="",
|
||||
thrift_checksum_real="",
|
||||
thrift_checksum_next="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
@ -313,14 +348,17 @@ Example(s):
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
checksum_next="",
|
||||
yaml_path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
checksum_head="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_checksum_head="",
|
||||
thrift_checksum_real="",
|
||||
thrift_checksum_next="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
@ -349,14 +387,17 @@ Example(s):
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
checksum_next="",
|
||||
yaml_path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
checksum_head="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_checksum_head="",
|
||||
thrift_checksum_real="",
|
||||
thrift_checksum_next="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
|
@ -58,8 +58,8 @@ class Device:
|
||||
@dataclass(repr=False)
|
||||
class SymExprHint(_Union):
|
||||
as_int: Annotated[int, 10]
|
||||
as_float: Annotated[float, 20]
|
||||
as_bool: Annotated[bool, 30]
|
||||
as_bool: Annotated[bool, 20]
|
||||
as_float: Annotated[float, 30]
|
||||
|
||||
|
||||
# This is for storing the symbolic expressions behind symints/symfloats/symbools
|
||||
|
@ -1,7 +1,7 @@
|
||||
// @generated by update_schema.py
|
||||
// checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>>
|
||||
// checksum<<0e89c5e620ad16c05bfe4fa2060ad43dcb0938dc31d77faad36b92f216c2c903>>
|
||||
|
||||
namespace py3 torch._export.schema
|
||||
namespace py3 torch._export
|
||||
namespace cpp2 torch._export.schema
|
||||
|
||||
enum Layout {
|
||||
@ -51,8 +51,8 @@ struct Device {
|
||||
|
||||
union SymExprHint {
|
||||
10: i64 as_int;
|
||||
20: double as_float;
|
||||
30: bool as_bool;
|
||||
20: bool as_bool;
|
||||
30: double as_float;
|
||||
}
|
||||
|
||||
struct SymExpr {
|
||||
|
@ -1,5 +1,5 @@
|
||||
# @generated by update_schema.py
|
||||
# checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>>
|
||||
# checksum<<0335ca6e44a8a815ea638d538de0ad4f78a644af2689f6e93c0e8219117466e7>>
|
||||
Argument:
|
||||
kind: union
|
||||
fields:
|
||||
@ -380,10 +380,10 @@ SymExprHint:
|
||||
fields:
|
||||
as_int:
|
||||
type: int
|
||||
as_float:
|
||||
type: float
|
||||
as_bool:
|
||||
type: bool
|
||||
as_float:
|
||||
type: float
|
||||
SymFloat:
|
||||
kind: union
|
||||
fields:
|
||||
|
@ -31,7 +31,7 @@ def _staged_schema():
|
||||
thrift_type_defs: Dict[str, str] = {}
|
||||
|
||||
def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
def dump_type(t) -> Tuple[str, str, str]:
|
||||
def dump_type(t, level: int) -> Tuple[str, str, str]:
|
||||
CPP_TYPE_MAP = {
|
||||
str: "std::string",
|
||||
int: "int64_t",
|
||||
@ -90,20 +90,21 @@ def _staged_schema():
|
||||
"",
|
||||
)
|
||||
elif o == Union:
|
||||
assert level == 0, "Optional is only supported at the top level."
|
||||
args = typing.get_args(t)
|
||||
assert len(args) == 2 and args[1] == type(None)
|
||||
yaml_type, cpp_type, thrift_type = dump_type(args[0])
|
||||
yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1)
|
||||
return (
|
||||
f"Optional[{yaml_type}]",
|
||||
f"std::optional<{cpp_type}>",
|
||||
f"optional {thrift_type}",
|
||||
)
|
||||
elif o == Annotated:
|
||||
return dump_type(t.__origin__)
|
||||
return dump_type(t.__origin__, level)
|
||||
else:
|
||||
raise AssertionError(f"Type {t} is not supported in export schema.")
|
||||
yaml_arg_types, cpp_arg_types, thrift_arg_types = zip(
|
||||
*[dump_type(x) for x in typing.get_args(t)]
|
||||
*[dump_type(x, level + 1) for x in typing.get_args(t)]
|
||||
)
|
||||
return (
|
||||
(f"{yaml_head}[{', '.join(yaml_arg_types)}]"),
|
||||
@ -136,7 +137,7 @@ def _staged_schema():
|
||||
)
|
||||
|
||||
def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str], str, int]:
|
||||
t, cpp_type, thrift_type = dump_type(f.type)
|
||||
t, cpp_type, thrift_type = dump_type(f.type, 0)
|
||||
ret = {"type": t}
|
||||
cpp_default: Optional[str] = None
|
||||
assert (
|
||||
@ -455,7 +456,7 @@ void from_json(const nlohmann::json& j, ForwardRef<T>& p) {{
|
||||
}} // namespace torch
|
||||
"""
|
||||
thrift_schema = f"""
|
||||
namespace py3 torch._export.schema
|
||||
namespace py3 torch._export
|
||||
namespace cpp2 torch._export.schema
|
||||
{chr(10).join(thrift_enum_defs)}
|
||||
{chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())}
|
||||
@ -528,21 +529,24 @@ def _diff_schema(dst, src):
|
||||
return additions, subtractions
|
||||
|
||||
|
||||
def _hash_schema(s):
|
||||
return hashlib.sha256(repr(s).encode("utf-8")).hexdigest()
|
||||
def _hash_content(s: str):
|
||||
return hashlib.sha256(s.strip().encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Commit:
|
||||
result: Dict[str, Any]
|
||||
checksum_result: str
|
||||
checksum_next: str
|
||||
yaml_path: str
|
||||
additions: Dict[str, Any]
|
||||
subtractions: Dict[str, Any]
|
||||
base: Dict[str, Any]
|
||||
checksum_base: Optional[str]
|
||||
checksum_head: Optional[str]
|
||||
cpp_header: str
|
||||
cpp_header_path: str
|
||||
thrift_checksum_head: Optional[str]
|
||||
thrift_checksum_real: Optional[str]
|
||||
thrift_checksum_next: str
|
||||
thrift_schema: str
|
||||
thrift_schema_path: str
|
||||
|
||||
@ -555,13 +559,26 @@ def update_schema():
|
||||
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
|
||||
_check(match is not None, "checksum not found in schema.yaml")
|
||||
assert match is not None
|
||||
checksum_base = match.group(1)
|
||||
checksum_head = match.group(1)
|
||||
|
||||
thrift_content = importlib.resources.read_text(__package__, "schema.thrift")
|
||||
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content)
|
||||
_check(match is not None, "checksum not found in schema.thrift")
|
||||
assert match is not None
|
||||
thrift_checksum_head = match.group(1)
|
||||
thrift_content = thrift_content.splitlines()
|
||||
assert thrift_content[0].startswith("// @" + "generated")
|
||||
assert thrift_content[1].startswith("// checksum<<")
|
||||
thrift_checksum_real = _hash_content("\n".join(thrift_content[2:]))
|
||||
|
||||
from yaml import load, Loader
|
||||
|
||||
dst = load(content, Loader=Loader)
|
||||
assert isinstance(dst, dict)
|
||||
else:
|
||||
checksum_base = None
|
||||
checksum_head = None
|
||||
thrift_checksum_head = None
|
||||
thrift_checksum_real = None
|
||||
dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
|
||||
|
||||
src, cpp_header, thrift_schema = _staged_schema()
|
||||
@ -574,14 +591,17 @@ def update_schema():
|
||||
|
||||
return _Commit(
|
||||
result=src,
|
||||
checksum_result=_hash_schema(src),
|
||||
checksum_next=_hash_content(repr(src)),
|
||||
yaml_path=yaml_path,
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base=checksum_base,
|
||||
checksum_head=checksum_head,
|
||||
cpp_header=cpp_header,
|
||||
cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h",
|
||||
thrift_checksum_head=thrift_checksum_head,
|
||||
thrift_checksum_real=thrift_checksum_real,
|
||||
thrift_checksum_next=_hash_content(thrift_schema),
|
||||
thrift_schema=thrift_schema,
|
||||
thrift_schema_path=thrift_schema_path,
|
||||
)
|
||||
|
30
torch/csrc/utils/generated_serialization_types.h
generated
30
torch/csrc/utils/generated_serialization_types.h
generated
@ -1,5 +1,5 @@
|
||||
// @generated by update_schema.py
|
||||
// checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>>
|
||||
// checksum<<0335ca6e44a8a815ea638d538de0ad4f78a644af2689f6e93c0e8219117466e7>>
|
||||
// clang-format off
|
||||
|
||||
#pragma once
|
||||
@ -191,11 +191,11 @@ class SymExprHint {
|
||||
|
||||
public:
|
||||
enum class Tag {
|
||||
AS_INT, AS_FLOAT, AS_BOOL
|
||||
AS_INT, AS_BOOL, AS_FLOAT
|
||||
};
|
||||
|
||||
private:
|
||||
std::variant<Void, int64_t, double, bool> variant_;
|
||||
std::variant<Void, int64_t, bool, double> variant_;
|
||||
Tag tag_;
|
||||
|
||||
public:
|
||||
@ -207,11 +207,11 @@ class SymExprHint {
|
||||
return std::get<1>(variant_);
|
||||
}
|
||||
|
||||
const double& get_as_float() const {
|
||||
const bool& get_as_bool() const {
|
||||
return std::get<2>(variant_);
|
||||
}
|
||||
|
||||
const bool& get_as_bool() const {
|
||||
const double& get_as_float() const {
|
||||
return std::get<3>(variant_);
|
||||
}
|
||||
|
||||
@ -221,14 +221,14 @@ class SymExprHint {
|
||||
nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int();
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) {
|
||||
nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float();
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_t.tag_ == Tag::AS_BOOL) {
|
||||
nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool();
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) {
|
||||
nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
friend void from_json(const nlohmann::json& nlohmann_json_j, SymExprHint& nlohmann_json_t) {
|
||||
@ -238,14 +238,14 @@ class SymExprHint {
|
||||
nlohmann_json_t.tag_ = Tag::AS_INT;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_float")) {
|
||||
nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_float").template get<double>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_FLOAT;
|
||||
if (nlohmann_json_j.contains("as_bool")) {
|
||||
nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_bool").template get<bool>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_BOOL;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("as_bool")) {
|
||||
nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_bool").template get<bool>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_BOOL;
|
||||
if (nlohmann_json_j.contains("as_float")) {
|
||||
nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_float").template get<double>());
|
||||
nlohmann_json_t.tag_ = Tag::AS_FLOAT;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user