Files
pytorch/torch/_export/serde/schema_check.py
Maggie Moss c855f8632e Pyrefly suppressions 7/n (#164913)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
 INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164913
Approved by: https://github.com/oulgen
2025-10-08 07:27:17 +00:00

742 lines
24 KiB
Python

# mypy: allow-untyped-defs
import dataclasses
import hashlib
import inspect
import re
import typing
from enum import IntEnum
from typing import Annotated, Any, ForwardRef, Optional, Union
from torch._export.serde import schema
from torch._export.serde.union import _Union
class SchemaUpdateError(Exception):
pass
def _check(x, msg):
if not x:
raise SchemaUpdateError(msg)
_CPP_TYPE_MAP = {
str: "std::string",
int: "int64_t",
float: "F64",
bool: "bool",
}
_THRIFT_TYPE_MAP = {
str: "string",
int: "i64",
float: "double",
bool: "bool",
}
def _staged_schema():
yaml_ret: dict[str, Any] = {}
defs = {}
cpp_enum_defs: dict[str, str] = {}
cpp_class_defs: dict[str, str] = {}
cpp_type_decls: list[str] = []
cpp_json_defs: list[str] = []
thrift_enum_defs: list[str] = []
thrift_type_defs: dict[str, str] = {}
def _handle_aggregate(ty) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
def dump_type(t, level: int) -> tuple[str, str, str]:
if getattr(t, "__name__", None) in cpp_enum_defs:
return t.__name__, "int64_t", t.__name__
elif t in _CPP_TYPE_MAP:
return (t.__name__, _CPP_TYPE_MAP[t], _THRIFT_TYPE_MAP[t])
elif isinstance(t, str):
assert t in defs
assert t not in cpp_enum_defs
assert "[" not in t
return t, f"ForwardRef<{t}>", t
elif isinstance(t, ForwardRef):
return (
t.__forward_arg__,
f"ForwardRef<{t.__forward_arg__}>",
t.__forward_arg__,
)
elif o := typing.get_origin(t):
# Lemme know if there's a better way to do this.
if o == list:
yaml_head, cpp_head, thrift_head, thrift_tail = (
"List",
"std::vector",
"list<",
">",
)
elif o == dict:
yaml_head, cpp_head, thrift_head, thrift_tail = (
"Dict",
"std::unordered_map",
"map<",
">",
)
elif o == Union:
assert level == 0, "Optional is only supported at the top level."
args = typing.get_args(t)
assert len(args) == 2 and args[1] == type(None)
yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1)
return (
f"Optional[{yaml_type}]",
f"std::optional<{cpp_type}>",
f"optional {thrift_type}",
)
elif o == Annotated:
return dump_type(t.__origin__, level)
else:
raise AssertionError(f"Type {t} is not supported in export schema.")
yaml_arg_types, cpp_arg_types, thrift_arg_types = zip(
*[dump_type(x, level + 1) for x in typing.get_args(t)]
)
return (
(f"{yaml_head}[{', '.join(yaml_arg_types)}]"),
(f"{cpp_head}<{', '.join(cpp_arg_types)}>"),
f"{thrift_head}{', '.join(thrift_arg_types)}{thrift_tail}",
)
elif isinstance(t, type):
return (t.__name__, t.__name__, t.__name__)
else:
raise AssertionError(f"Type {t} is not supported in export schema.")
def dump_cpp_value(v) -> str:
if v is None:
return "std::nullopt"
elif v is True:
return "true"
elif v is False:
return "false"
elif v == {}:
return "{}"
elif v == []:
return "{}"
elif v == ():
return "{}"
elif isinstance(v, str):
return f'"{v}"'
else:
raise AssertionError(
f"Default value {v} is not supported yet in export schema."
)
def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]:
t, cpp_type, thrift_type = dump_type(f.type, 0)
ret = {"type": t}
cpp_default: Optional[str] = None
assert typing.get_origin(f.type) == Annotated, (
f"Field {f.name} must be annotated with an integer id."
)
thrift_id = f.type.__metadata__[0]
assert type(thrift_id) is int, (
f"Field {f.name} must be annotated with an integer id."
)
value = dataclasses.MISSING
if f.default is not dataclasses.MISSING:
value = f.default
elif f.default_factory is not dataclasses.MISSING:
value = f.default_factory()
if value is not dataclasses.MISSING:
default = str(value)
ret["default"] = default
cpp_default = dump_cpp_value(value)
if t.startswith("Optional[") and value is not None:
raise AssertionError(
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
)
return ret, cpp_type, cpp_default, thrift_type, thrift_id
yaml_ret = {}
cpp_ret = {}
thrift_ret = {}
thrift_ids = set()
for f in dataclasses.fields(ty):
yaml_res, cpp_type, cpp_default, thrift_type, thrift_id = dump_field(f)
yaml_ret[f.name] = yaml_res
cpp_ret[f.name] = {"cpp_type": cpp_type, "cpp_default": cpp_default}
thrift_ret[f.name] = {"thrift_type": thrift_type, "thrift_id": thrift_id}
if thrift_id in thrift_ids:
raise AssertionError(
f"Duplicate thrift id {thrift_id} for field {f.name} in {ty.__name__}."
)
thrift_ids.add(thrift_id)
return yaml_ret, cpp_ret, thrift_ret
def _handle_int_enum(name, ty):
yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}
cpp_enum_defs[name] = f"""
enum class {name} {{
{chr(10).join([f" {x.name} = {x.value}," for x in ty])}
}};
inline std::string_view printEnum(const {name}& e) {{
switch (e) {{
{chr(10).join([f" case {name}::{x.name}: return {chr(34)}{x.name}{chr(34)};" for x in ty])}
default:
throw std::runtime_error("Unknown enum value");
}}
}}
inline void parseEnum(std::string_view s, {name}& t) {{
{chr(10).join([f" if (s == {chr(34)}{x.name}{chr(34)}) {{ t = {name}::{x.name}; return; }}" for x in ty])}
throw std::runtime_error("Unknown enum value: " + std::string{{s}});
}}
"""
thrift_enum_defs.append(
f"""
enum {name} {{
{chr(10).join([f" {x.name} = {x.value}," for x in ty])}
}}
"""
)
def _handle_struct(name, ty):
fields, cpp_fields, thrift_fields = _handle_aggregate(ty)
yaml_ret[name] = {"kind": "struct", "fields": fields}
field_decls = "\n".join(
f" {f['cpp_type']} {name}{' = ' + f['cpp_default'] if f['cpp_default'] is not None else ''};"
for name, f in cpp_fields.items()
)
def accessor(name, ty):
type_name = fields[name]["type"]
if type_name in cpp_enum_defs:
return f"""
{type_name} get_{name}() const {{
return static_cast<{type_name}>({name});
}}
void set_{name}({type_name} def) {{
{name} = static_cast<int64_t>(def);
}}
"""
return f"""
const {ty}& get_{name}() const {{
return {name};
}}
void set_{name}({ty} def) {{
{name} = std::move(def);
}}
"""
to_json_decl = f"void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t)"
to_json_def = f"""{{
{chr(10).join([f' nlohmann_json_j["{name}"] = nlohmann_json_t.{name};' for name, f in cpp_fields.items()])}
}}
"""
from_json_decl = f"void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t)"
from_json_def = f"""{{
{name} nlohmann_json_default_obj;
{
chr(10).join(
[
f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});'
for name, f in cpp_fields.items()
]
)
}
}}
"""
cpp_class_defs[name] = f"""
class {name} {{
private:
{field_decls}
public:
{"".join([accessor(name, f["cpp_type"]) for name, f in cpp_fields.items()])}
friend {to_json_decl};
friend {from_json_decl};
}};
"""
cpp_json_defs.append(f"inline {to_json_decl} {to_json_def}")
cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}")
cpp_type_decls.append(f"class {name};")
thrift_type_defs[name] = f"""
struct {name} {{
{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())}
}}"""
def _handle_union(name, ty):
fields, cpp_fields, thrift_fields = _handle_aggregate(ty)
yaml_ret[name] = {"kind": "union", "fields": fields}
def accessor(name, ty, idx):
return f"""
const {ty}& get_{name}() const {{
return std::get<{idx + 1}>(variant_);
}}
void set_{name}({ty} def) {{
variant_.emplace<{idx + 1}>(std::move(def));
tag_ = Tag::{name.upper()};
}}
"""
to_json_branches = "".join(
[
f"""
if (nlohmann_json_t.tag_ == Tag::{name.upper()}) {{
nlohmann_json_j["{name}"] = nlohmann_json_t.get_{name}();
return;
}}"""
for idx, (name, f) in enumerate(cpp_fields.items())
]
)
from_json_branches = "".join(
[
f"""
if (nlohmann_json_j.contains("{name}")) {{
nlohmann_json_t.variant_.emplace<{idx + 1}>(nlohmann_json_j.at("{name}").template get<{f["cpp_type"]}>());
nlohmann_json_t.tag_ = Tag::{name.upper()};
return;
}}"""
for idx, (name, f) in enumerate(cpp_fields.items())
]
)
cpp_class_defs[name] = f"""
class {name} {{
struct Void {{}};
public:
enum class Tag {{
{", ".join([name.upper() for name in cpp_fields])}
}};
private:
std::variant<Void, {", ".join(f["cpp_type"] for f in cpp_fields.values())}> variant_;
Tag tag_;
public:
Tag tag() const {{
return tag_;
}}
{"".join([accessor(name, f["cpp_type"], idx) for idx, (name, f) in enumerate(cpp_fields.items())])}
friend void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t) {{
{to_json_branches}
}}
friend void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t) {{
{from_json_branches}
}}
}};
inline std::string_view printEnum(const {name}::Tag& e) {{
switch (e) {{
{chr(10).join([f" case {name}::Tag::{x.upper()}: return {chr(34)}{x.upper()}{chr(34)};" for x in cpp_fields])}
default:
throw std::runtime_error("Unknown enum value");
}}
}}
inline void parseEnum(std::string_view s, {name}::Tag& t) {{
{chr(10).join([f" if (s == {chr(34)}{x.upper()}{chr(34)}) {{ t = {name}::Tag::{x.upper()}; return; }}" for x in cpp_fields])}
throw std::runtime_error("Unknown enum value: " + std::string{{s}});
}}
"""
cpp_type_decls.append(f"class {name};")
thrift_type_defs[name] = f"""
union {name} {{
{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())}
}}"""
for name in dir(schema):
if name.startswith("_"):
continue
value = getattr(schema, name)
if hasattr(value, "__module__") and value.__module__ != schema.__name__:
continue
defs[name] = value
class_ordering = {}
for name, value in defs.items():
if isinstance(value, type):
if issubclass(value, IntEnum):
_handle_int_enum(name, value)
elif dataclasses.is_dataclass(value):
class_ordering[name] = inspect.findsource(value)[1]
if issubclass(value, _Union):
_handle_union(name, value)
else:
_handle_struct(name, value)
else:
raise AssertionError(f"Unknown schema type {name}: {value}")
elif isinstance(value, (int, tuple)):
assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION")
else:
raise AssertionError(f"Unknown variable {name}: {value}")
yaml_ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"])
assert all(x > 0 for x in yaml_ret["SCHEMA_VERSION"])
yaml_ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"]
assert yaml_ret["TREESPEC_VERSION"] > 0
cpp_header = f"""
#pragma once
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>
#include <nlohmann/json.hpp>
#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN
#define NLOHMANN_JSON_NAMESPACE_BEGIN namespace nlohmann {{
#endif
#ifndef NLOHMANN_JSON_NAMESPACE_END
#define NLOHMANN_JSON_NAMESPACE_END }}
#endif
// https://github.com/nlohmann/json/pull/2117
NLOHMANN_JSON_NAMESPACE_BEGIN
template <typename T>
struct adl_serializer<std::optional<T>> {{
static void to_json(json& j, const std::optional<T>& opt) {{
if (opt == std::nullopt) {{
j = nullptr;
}} else {{
j = *opt; // this will call adl_serializer<T>::to_json which will
// find the free function to_json in T's namespace!
}}
}}
static void from_json(const json& j, std::optional<T>& opt) {{
if (j.is_null()) {{
opt = std::nullopt;
}} else {{
opt = j.template get<T>(); // same as above, but with
// adl_serializer<T>::from_json
}}
}}
}};
NLOHMANN_JSON_NAMESPACE_END
namespace torch {{
namespace _export {{
template <typename T>
class ForwardRef {{
static_assert(!std::is_reference_v<T>, "ForwardRef cannot be a reference type");
public:
ForwardRef(): ptr_(std::make_unique<T>()) {{}}
ForwardRef(ForwardRef<T>&&);
ForwardRef(const ForwardRef<T>& other): ptr_(std::make_unique<T>(*other.ptr_)) {{}}
ForwardRef<T>& operator=(ForwardRef<T>&&);
ForwardRef<T>& operator=(const ForwardRef<T>& other) {{
ptr_ = std::make_unique<T>(*other.ptr_);
return *this;
}}
~ForwardRef();
const T& operator*() const {{
return *ptr_;
}}
const T* operator->() const {{
return ptr_.get();
}}
void emplace(T&& t) {{
ptr_ = std::make_unique<T>(std::move(t));
}}
private:
std::unique_ptr<T> ptr_;
}};
template <typename T>
void to_json(nlohmann::json& j, const ForwardRef<T>& p) {{
j = *p;
}}
template <typename T>
void from_json(const nlohmann::json& j, ForwardRef<T>& p) {{
p.emplace(j.template get<T>());
}}
class F64 {{
public:
double get() const {{
return value_;
}}
void set(double value) {{
value_ = value;
}}
private:
double value_;
}};
inline void to_json(nlohmann::json& j, const F64& f) {{
if (std::isinf(f.get())) {{
j = "Infinity";
}} else if (std::isinf(-f.get())) {{
j = "-Infinity";
}} else if (std::isnan(f.get())) {{
j = "NaN";
}} else {{
j = f.get();
}}
}}
inline void from_json(const nlohmann::json& j, F64& f) {{
if (j == "Infinity") {{
f.set(std::numeric_limits<double>::infinity());
}} else if (j == "-Infinity") {{
f.set(-std::numeric_limits<double>::infinity());
}} else if (j == "NaN") {{
f.set(std::numeric_limits<double>::quiet_NaN());
}} else {{
f.set(j.get<double>());
}}
}}
{chr(10).join(cpp_type_decls)}
{"".join(cpp_enum_defs.values())}
{"".join(dict(sorted(cpp_class_defs.items(), key=lambda x: class_ordering[x[0]])).values())}
{chr(10).join(cpp_json_defs)}
template <typename T> ForwardRef<T>::ForwardRef(ForwardRef<T>&&) = default;
template <typename T> ForwardRef<T>& ForwardRef<T>::operator=(ForwardRef<T>&&) = default;
template <typename T> ForwardRef<T>::~ForwardRef() = default;
}} // namespace _export
}} // namespace torch
"""
thrift_schema = f"""
namespace py3 torch._export
namespace cpp2 torch._export.schema
{chr(10).join(thrift_enum_defs)}
{chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())}
"""
return yaml_ret, cpp_header, thrift_schema
def _diff_schema(dst, src):
additions = {key: src[key] for key in src.keys() - dst.keys()}
subtractions = {key: dst[key] for key in dst.keys() - src.keys()}
common_keys = src.keys() & dst.keys()
versions = {"SCHEMA_VERSION", "TREESPEC_VERSION"}
common_keys -= versions
for key in common_keys:
src_kind = src[key]["kind"]
src_fields = src[key]["fields"]
dst_kind = dst[key]["kind"]
dst_fields = dst[key]["fields"]
_check(
src_kind == dst_kind,
f"Type {key} changed kind from {dst_kind} to {src_kind}",
)
assert isinstance(src_fields, dict) and isinstance(dst_fields, dict)
added_fields = {
key: src_fields[key] for key in src_fields.keys() - dst_fields.keys()
}
subtracted_fields = {
key: dst_fields[key] for key in dst_fields.keys() - src_fields.keys()
}
common_fields = src_fields.keys() & dst_fields.keys()
for field in common_fields:
src_field = src_fields[field]
dst_field = dst_fields[field]
if src_kind == "struct":
_check(
src_field["type"] == dst_field["type"],
f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}",
)
if "default" in src_field and "default" not in dst_field:
added_fields[field] = {}
added_fields[field]["default"] = src_field["default"]
if "default" not in src_field and "default" in dst_field:
subtracted_fields[field] = {}
subtracted_fields[field]["default"] = dst_field["default"]
elif src_kind == "enum":
_check(
src_field == dst_field,
f"Value of the enum field {key}.{field} changed from {dst_field} to {src_field}",
)
elif src_kind == "union":
_check(
src_field["type"] == dst_field["type"],
f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}",
)
else:
raise AssertionError(f"Unknown kind {src_kind}: {key}")
if len(added_fields) > 0:
assert key not in additions
additions[key] = {}
additions[key]["fields"] = added_fields
if len(subtracted_fields) > 0:
assert key not in subtractions
subtractions[key] = {}
subtractions[key]["fields"] = subtracted_fields
return additions, subtractions
def _hash_content(s: str):
return hashlib.sha256(s.strip().encode("utf-8")).hexdigest()
@dataclasses.dataclass
class _Commit:
result: dict[str, Any]
checksum_next: str
yaml_path: str
additions: dict[str, Any]
subtractions: dict[str, Any]
base: dict[str, Any]
checksum_head: Optional[str]
cpp_header: str
cpp_header_path: str
thrift_checksum_head: Optional[str]
thrift_checksum_real: Optional[str]
thrift_checksum_next: str
thrift_schema: str
thrift_schema_path: str
def update_schema():
import importlib.resources
# pyrefly: ignore # bad-argument-type
if importlib.resources.is_resource(__package__, "schema.yaml"):
# pyrefly: ignore # bad-argument-type
content = importlib.resources.read_text(__package__, "schema.yaml")
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
_check(match is not None, "checksum not found in schema.yaml")
assert match is not None
checksum_head = match.group(1)
thrift_content = importlib.resources.read_text(
# pyrefly: ignore # bad-argument-type
__package__,
"export_schema.thrift",
)
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content)
_check(match is not None, "checksum not found in export_schema.thrift")
assert match is not None
thrift_checksum_head = match.group(1)
thrift_content = thrift_content.splitlines()
assert thrift_content[0].startswith("// @" + "generated")
assert thrift_content[1].startswith("// checksum<<")
thrift_checksum_real = _hash_content("\n".join(thrift_content[2:]))
from yaml import load, Loader
dst = load(content, Loader=Loader)
assert isinstance(dst, dict)
else:
checksum_head = None
thrift_checksum_head = None
thrift_checksum_real = None
dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
src, cpp_header, thrift_schema = _staged_schema()
additions, subtractions = _diff_schema(dst, src)
# pyrefly: ignore # missing-attribute
yaml_path = __package__.replace(".", "/") + "/schema.yaml"
# pyrefly: ignore # missing-attribute
thrift_schema_path = __package__.replace(".", "/") + "/export_schema.thrift"
torch_prefix = "torch/"
assert yaml_path.startswith(torch_prefix) # sanity check
assert thrift_schema_path.startswith(torch_prefix) # sanity check
return _Commit(
result=src,
checksum_next=_hash_content(repr(src)),
yaml_path=yaml_path,
additions=additions,
subtractions=subtractions,
base=dst,
checksum_head=checksum_head,
cpp_header=cpp_header,
cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h",
thrift_checksum_head=thrift_checksum_head,
thrift_checksum_real=thrift_checksum_real,
thrift_checksum_next=_hash_content(thrift_schema),
thrift_schema=thrift_schema,
thrift_schema_path=thrift_schema_path,
)
def check(commit: _Commit, force_unsafe: bool = False):
next_version = None
reason = ""
# Step 1: Detect major schema updates.
if len(commit.additions) > 0:
for k, v in commit.additions.items():
if k not in commit.base:
continue
kind = commit.result[k]["kind"]
fields = v["fields"]
for f, d in fields.items():
if kind == "struct" and "default" not in d:
reason += (
f"Field {k}.{f} is added to schema.py without a default value as an incompatible change "
+ "which requires major version bump.\n"
)
next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]
if len(commit.subtractions) > 0:
for k, v in commit.subtractions.items():
if k not in commit.result:
continue
for f in v["fields"]:
reason = f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n"
next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]
if force_unsafe:
reason += "--force-unsafe is used."
next_version = commit.result["SCHEMA_VERSION"]
else:
# Step 2: Detect minor schema updates.
if next_version is None and len(commit.additions) > 0:
for k, v in commit.additions.items():
for f in v["fields"]:
reason += (
f"Field {k}.{f} is added to schema.py as an compatible change "
+ "which still requires minor version bump.\n"
)
next_version = [
commit.base["SCHEMA_VERSION"][0],
commit.base["SCHEMA_VERSION"][1] + 1,
]
if next_version is None and len(commit.subtractions) > 0:
for k, v in commit.subtractions.items():
for f in v["fields"]:
reason += (
f"Field {k}.{f} is removed from schema.py as an compatible change "
+ "which still requires minor version bump.\n"
)
next_version = [
commit.base["SCHEMA_VERSION"][0],
commit.base["SCHEMA_VERSION"][1] + 1,
]
return next_version, reason