Tweak schema_check to handle annotated builtin types (#145154)

As of python 3.9 annotated lists can be written as `list[T]` and `List[T]` has been deprecated.  However schema_check was converting `list[T]` to simply be `list`. This change teaches it to handle `list[T]` the same as `List[T]`.

A couple small drive-by changes I noticed as well:
- Path concatenation should use `os.path.join`, not `+`
- Spelling in error message

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145154
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-18 08:47:47 -08:00
committed by PyTorch MergeBot
parent 9e0437a04a
commit cd8d0fa20c
3 changed files with 26 additions and 26 deletions

View File

@ -80,9 +80,9 @@ if __name__ == "__main__":
print(yaml_content)
print("\nWill write the above schema to" + args.prefix + commit.yaml_path)
else:
with open(args.prefix + commit.yaml_path, "w") as f:
with open(os.path.join(args.prefix, commit.yaml_path), "w") as f:
f.write(yaml_content)
with open(args.prefix + commit.cpp_header_path, "w") as f:
with open(os.path.join(args.prefix, commit.cpp_header_path), "w") as f:
f.write(cpp_header)
with open(args.prefix + commit.thrift_schema_path, "w") as f:
with open(os.path.join(args.prefix, commit.thrift_schema_path), "w") as f:
f.write(thrift_schema)

View File

@ -14,7 +14,7 @@ class TestSchema(TestCase):
msg = """
Detected an invalidated change to export schema. Please run the following script to update the schema:
Example(s):
python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory>
python scripts/export/update_schema.py --prefix <path_to_torch_development_directory>
"""
if IS_FBCODE:
@ -32,7 +32,7 @@ Example(s):
msg = """
Detected an unexpected change to schema.thrift. Please update schema.py instead and run the following script:
Example(s):
python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory>
python scripts/export/update_schema.py --prefix <path_to_torch_development_directory>
"""
if IS_FBCODE:

View File

@ -20,6 +20,21 @@ def _check(x, msg):
raise SchemaUpdateError(msg)
_CPP_TYPE_MAP = {
str: "std::string",
int: "int64_t",
float: "double",
bool: "bool",
}
_THRIFT_TYPE_MAP = {
str: "string",
int: "i64",
float: "double",
bool: "bool",
}
def _staged_schema():
yaml_ret: Dict[str, Any] = {}
defs = {}
@ -32,27 +47,10 @@ def _staged_schema():
def _handle_aggregate(ty) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
def dump_type(t, level: int) -> tuple[str, str, str]:
CPP_TYPE_MAP = {
str: "std::string",
int: "int64_t",
float: "double",
bool: "bool",
}
THRIFT_TYPE_MAP = {
str: "string",
int: "i64",
float: "double",
bool: "bool",
}
if isinstance(t, type):
if t.__name__ in cpp_enum_defs:
return t.__name__, "int64_t", t.__name__
else:
return (
t.__name__,
CPP_TYPE_MAP.get(t, t.__name__),
THRIFT_TYPE_MAP.get(t, t.__name__),
)
if getattr(t, "__name__", None) in cpp_enum_defs:
return t.__name__, "int64_t", t.__name__
elif t in _CPP_TYPE_MAP:
return (t.__name__, _CPP_TYPE_MAP[t], _THRIFT_TYPE_MAP[t])
elif isinstance(t, str):
assert t in defs
assert t not in cpp_enum_defs
@ -102,6 +100,8 @@ def _staged_schema():
(f"{cpp_head}<{', '.join(cpp_arg_types)}>"),
f"{thrift_head}{', '.join(thrift_arg_types)}{thrift_tail}",
)
elif isinstance(t, type):
return (t.__name__, t.__name__, t.__name__)
else:
raise AssertionError(f"Type {t} is not supported in export schema.")