From cd8d0fa20c4bc865cbf153d6bf7ae375bbe62afd Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 18 Jan 2025 08:47:47 -0800 Subject: [PATCH] Tweak schema_check to handle annotated builtin types (#145154) As of python 3.9 annotated lists can be written as `list[T]` and `List[T]` has been deprecated. However schema_check was converting `list[T]` to simply be `list`. This change teaches it to handle `list[T]` the same as `List[T]`. A couple small drive-by changes I noticed as well: - Path concatenation should use `os.path.join`, not `+` - Spelling in error message Pull Request resolved: https://github.com/pytorch/pytorch/pull/145154 Approved by: https://github.com/bobrenjc93 --- scripts/export/update_schema.py | 6 ++--- test/export/test_schema.py | 4 +-- torch/_export/serde/schema_check.py | 42 ++++++++++++++--------------- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/scripts/export/update_schema.py b/scripts/export/update_schema.py index bc76e4b7bfc7..904cf2b7d8c3 100644 --- a/scripts/export/update_schema.py +++ b/scripts/export/update_schema.py @@ -80,9 +80,9 @@ if __name__ == "__main__": print(yaml_content) print("\nWill write the above schema to" + args.prefix + commit.yaml_path) else: - with open(args.prefix + commit.yaml_path, "w") as f: + with open(os.path.join(args.prefix, commit.yaml_path), "w") as f: f.write(yaml_content) - with open(args.prefix + commit.cpp_header_path, "w") as f: + with open(os.path.join(args.prefix, commit.cpp_header_path), "w") as f: f.write(cpp_header) - with open(args.prefix + commit.thrift_schema_path, "w") as f: + with open(os.path.join(args.prefix, commit.thrift_schema_path), "w") as f: f.write(thrift_schema) diff --git a/test/export/test_schema.py b/test/export/test_schema.py index fef9ee796d5e..27e8cd59f2da 100644 --- a/test/export/test_schema.py +++ b/test/export/test_schema.py @@ -14,7 +14,7 @@ class TestSchema(TestCase): msg = """ Detected an invalidated change to export schema. Please run the following script to update the schema: Example(s): - python scripts/export/update_schema.py --prefix + python scripts/export/update_schema.py --prefix """ if IS_FBCODE: @@ -32,7 +32,7 @@ Example(s): msg = """ Detected an unexpected change to schema.thrift. Please update schema.py instead and run the following script: Example(s): - python scripts/export/update_schema.py --prefix + python scripts/export/update_schema.py --prefix """ if IS_FBCODE: diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index 9aad486ba9fc..e0a45318f7a0 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -20,6 +20,21 @@ def _check(x, msg): raise SchemaUpdateError(msg) +_CPP_TYPE_MAP = { + str: "std::string", + int: "int64_t", + float: "double", + bool: "bool", +} + +_THRIFT_TYPE_MAP = { + str: "string", + int: "i64", + float: "double", + bool: "bool", +} + + def _staged_schema(): yaml_ret: Dict[str, Any] = {} defs = {} @@ -32,27 +47,10 @@ def _staged_schema(): def _handle_aggregate(ty) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: def dump_type(t, level: int) -> tuple[str, str, str]: - CPP_TYPE_MAP = { - str: "std::string", - int: "int64_t", - float: "double", - bool: "bool", - } - THRIFT_TYPE_MAP = { - str: "string", - int: "i64", - float: "double", - bool: "bool", - } - if isinstance(t, type): - if t.__name__ in cpp_enum_defs: - return t.__name__, "int64_t", t.__name__ - else: - return ( - t.__name__, - CPP_TYPE_MAP.get(t, t.__name__), - THRIFT_TYPE_MAP.get(t, t.__name__), - ) + if getattr(t, "__name__", None) in cpp_enum_defs: + return t.__name__, "int64_t", t.__name__ + elif t in _CPP_TYPE_MAP: + return (t.__name__, _CPP_TYPE_MAP[t], _THRIFT_TYPE_MAP[t]) elif isinstance(t, str): assert t in defs assert t not in cpp_enum_defs @@ -102,6 +100,8 @@ def _staged_schema(): (f"{cpp_head}<{', '.join(cpp_arg_types)}>"), f"{thrift_head}{', '.join(thrift_arg_types)}{thrift_tail}", ) + elif isinstance(t, type): + return (t.__name__, t.__name__, t.__name__) else: raise AssertionError(f"Type {t} is not supported in export schema.")