mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] Enforce serialization BC/FC with updater script. (#118424)
Summary: This diff implements a mechanism for safely update torch.export serialization schema, aka schema.py, which is the API surface having the strongest compatibility guarantee. The diff is consist of 3 changes: - Added a script to "build" or "materialize" schema.py into a platform neutral format (yaml), which serves as the committed form of the seialization schema. - Added unittest to compare against schema.py and schema.yaml, so that it forces developers to execute the updater script when there is mismatch between two files. - Added a checker inside the updater script, so that all the compatible change will result in a minor version bump, and all the incompatible changes will result in a major version bump. torch.export's serialization BC/FC policy is (tentatively) documented here: https://docs.google.com/document/d/1EN7JrHbOPDhbpLDtiYG4_BPUs7PttpXlbZ27FuwKhxg/edit#heading=h.pup7ir8rqjhx , we will update the As noted in the code doc, people should be able to run the following command to update schema properly from now on: ``` python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory> or buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/ ``` Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_schema buck run caffe2:update_export_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/ Differential Revision: D52971020 Pull Request resolved: https://github.com/pytorch/pytorch/pull/118424 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
697ca4f292
commit
2d37a046e7
71
scripts/export/update_schema.py
Normal file
71
scripts/export/update_schema.py
Normal file
@ -0,0 +1,71 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from torch._export.serde import schema_check
|
||||
from yaml import dump, Dumper
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(prog="update_schema")
|
||||
parser.add_argument(
|
||||
"--prefix", type=str, required=True, help="The root of pytorch directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Print the schema instead of writing it to file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-unsafe",
|
||||
action="store_true",
|
||||
help="!!! Only use this option when you are a chad. !!! Force to write the schema even if schema validation doesn't pass.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert os.path.exists(
|
||||
args.prefix
|
||||
), f"Assuming path {args.prefix} is the root of pytorch directory, but it doesn't exist."
|
||||
|
||||
commit = schema_check.update_schema()
|
||||
|
||||
if os.path.exists(args.prefix + commit.path):
|
||||
if commit.result["SCHEMA_VERSION"] < commit.base["SCHEMA_VERSION"]:
|
||||
raise RuntimeError(
|
||||
f"Schema version downgraded from {commit.base['SCHEMA_VERSION']} to {commit.result['SCHEMA_VERSION']}."
|
||||
)
|
||||
|
||||
if commit.result["TREESPEC_VERSION"] < commit.base["TREESPEC_VERSION"]:
|
||||
raise RuntimeError(
|
||||
f"Treespec version downgraded from {commit.base['TREESPEC_VERSION']} to {commit.result['TREESPEC_VERSION']}."
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
args.force_unsafe
|
||||
), "Existing schema yaml file not found, please use --force-unsafe to try again."
|
||||
|
||||
next_version, reason = schema_check.check(commit, args.force_unsafe)
|
||||
|
||||
if next_version is not None and next_version != commit.result["SCHEMA_VERSION"]:
|
||||
raise RuntimeError(
|
||||
f"Schema version is not updated from {commit.base['SCHEMA_VERSION']} to {next_version}.\n"
|
||||
+ f"Please either:\n"
|
||||
+ " 1. update schema.py to not break compatibility.\n"
|
||||
+ " or 2. bump the schema version to the expected value.\n"
|
||||
+ " or 3. use --force-unsafe to override schema.yaml (not recommended).\n "
|
||||
+ "and try again.\n"
|
||||
+ f"Reason: {reason}"
|
||||
)
|
||||
|
||||
header = (
|
||||
"# @" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
|
||||
)
|
||||
header += f"\n# checksum<<{commit.checksum_result}>>"
|
||||
payload = dump(commit.result, Dumper=Dumper, sort_keys=False)
|
||||
|
||||
content = header + "\n" + payload
|
||||
|
||||
if args.dry_run:
|
||||
print(content)
|
||||
print("\nWill write the above schema to" + args.prefix + commit.path)
|
||||
else:
|
||||
with open(args.prefix + commit.path, "w") as f:
|
||||
f.write(content)
|
1
setup.py
1
setup.py
@ -1286,6 +1286,7 @@ def main():
|
||||
"include/sleef.h",
|
||||
"_inductor/codegen/*.h",
|
||||
"_inductor/codegen/aoti_runtime/*.cpp",
|
||||
"_export/serde/*.yaml",
|
||||
"share/cmake/ATen/*.cmake",
|
||||
"share/cmake/Caffe2/*.cmake",
|
||||
"share/cmake/Caffe2/public/*.cmake",
|
||||
|
341
test/export/test_schema.py
Normal file
341
test/export/test_schema.py
Normal file
@ -0,0 +1,341 @@
|
||||
# Owner(s): ["oncall: export"]
|
||||
from torch._export.serde.schema_check import (
|
||||
_Commit,
|
||||
_diff_schema,
|
||||
check,
|
||||
SchemaUpdateError,
|
||||
update_schema,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_utils import IS_FBCODE, run_tests, TestCase
|
||||
|
||||
|
||||
class TestSchema(TestCase):
|
||||
def test_schema_compatibility(self):
|
||||
msg = """
|
||||
Detected an invalidated change to export schema. Please run the following script to update the schema:
|
||||
Example(s):
|
||||
python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory>
|
||||
"""
|
||||
|
||||
if IS_FBCODE:
|
||||
msg += """or
|
||||
buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/
|
||||
"""
|
||||
try:
|
||||
commit = update_schema()
|
||||
except SchemaUpdateError as e:
|
||||
self.fail(f"Failed to update schema: {e}\n{msg}")
|
||||
|
||||
self.assertEqual(commit.checksum_base, commit.checksum_result, msg)
|
||||
|
||||
def test_schema_diff(self):
|
||||
additions, subtractions = _diff_schema(
|
||||
{
|
||||
"Type0": {"kind": "struct", "fields": {}},
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
"field2": {"type": ""},
|
||||
"field3": {"type": "", "default": "[]"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field1": {"type": "", "default": "0"},
|
||||
"field2": {"type": "", "default": "[]"},
|
||||
"field3": {"type": ""},
|
||||
},
|
||||
},
|
||||
"Type1": {"kind": "struct", "fields": {}},
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
additions,
|
||||
{
|
||||
"Type1": {"kind": "struct", "fields": {}},
|
||||
"Type2": {
|
||||
"fields": {
|
||||
"field1": {"type": "", "default": "0"},
|
||||
"field2": {"default": "[]"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
subtractions,
|
||||
{
|
||||
"Type0": {"kind": "struct", "fields": {}},
|
||||
"Type2": {
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
"field3": {"default": "[]"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_schema_check(self):
|
||||
# Adding field without default value
|
||||
dst = {
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
src = {
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
"field1": {"type": ""},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
|
||||
additions, subtractions = _diff_schema(dst, src)
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [4, 1])
|
||||
|
||||
# Removing field
|
||||
dst = {
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
src = {
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
|
||||
additions, subtractions = _diff_schema(dst, src)
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [4, 1])
|
||||
|
||||
# Adding field with default value
|
||||
dst = {
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
src = {
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
"field1": {"type": "", "default": "[]"},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
|
||||
additions, subtractions = _diff_schema(dst, src)
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [3, 3])
|
||||
|
||||
# Changing field type
|
||||
dst = {
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
src = {
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field0": {"type": "int"},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
|
||||
with self.assertRaises(SchemaUpdateError):
|
||||
_diff_schema(dst, src)
|
||||
|
||||
# Adding new type.
|
||||
dst = {
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
src = {
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
},
|
||||
},
|
||||
"Type1": {"kind": "struct", "fields": {}},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
|
||||
additions, subtractions = _diff_schema(dst, src)
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [3, 3])
|
||||
|
||||
# Removing a type.
|
||||
dst = {
|
||||
"Type2": {
|
||||
"kind": "struct",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
src = {
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
|
||||
additions, subtractions = _diff_schema(dst, src)
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [3, 3])
|
||||
|
||||
# Adding new field in union.
|
||||
dst = {
|
||||
"Type2": {
|
||||
"kind": "union",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
src = {
|
||||
"Type2": {
|
||||
"kind": "union",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
"field1": {"type": ""},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
|
||||
additions, subtractions = _diff_schema(dst, src)
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [3, 3])
|
||||
|
||||
# Removing a field in union.
|
||||
dst = {
|
||||
"Type2": {
|
||||
"kind": "union",
|
||||
"fields": {
|
||||
"field0": {"type": ""},
|
||||
},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
src = {
|
||||
"Type2": {
|
||||
"kind": "union",
|
||||
"fields": {},
|
||||
},
|
||||
"SCHEMA_VERSION": [3, 2],
|
||||
}
|
||||
|
||||
additions, subtractions = _diff_schema(dst, src)
|
||||
|
||||
commit = _Commit(
|
||||
result=src,
|
||||
checksum_result="",
|
||||
path="",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [4, 1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
373
torch/_export/serde/schema.yaml
Normal file
373
torch/_export/serde/schema.yaml
Normal file
@ -0,0 +1,373 @@
|
||||
# @generated by update_schema.py
|
||||
# checksum<<a6f37692190dca3727e21e6c775ad99a0e440e640753228496f356f583611d40>>
|
||||
Argument:
|
||||
kind: union
|
||||
fields:
|
||||
as_none:
|
||||
type: Tuple[()]
|
||||
as_tensor:
|
||||
type: TensorArgument
|
||||
as_tensors:
|
||||
type: List[TensorArgument]
|
||||
as_int:
|
||||
type: int
|
||||
as_ints:
|
||||
type: List[int]
|
||||
as_float:
|
||||
type: float
|
||||
as_floats:
|
||||
type: List[float]
|
||||
as_string:
|
||||
type: str
|
||||
as_strings:
|
||||
type: List[str]
|
||||
as_sym_int:
|
||||
type: SymIntArgument
|
||||
as_sym_ints:
|
||||
type: List[SymIntArgument]
|
||||
as_scalar_type:
|
||||
type: ScalarType
|
||||
as_memory_format:
|
||||
type: MemoryFormat
|
||||
as_layout:
|
||||
type: Layout
|
||||
as_device:
|
||||
type: Device
|
||||
as_bool:
|
||||
type: bool
|
||||
as_bools:
|
||||
type: List[bool]
|
||||
as_sym_bool:
|
||||
type: SymBoolArgument
|
||||
as_sym_bools:
|
||||
type: List[SymBoolArgument]
|
||||
as_graph:
|
||||
type: GraphArgument
|
||||
as_optional_tensors:
|
||||
type: List[OptionalTensorArgument]
|
||||
as_custom_obj:
|
||||
type: CustomObjArgument
|
||||
BufferMutationSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: TensorArgument
|
||||
buffer_name:
|
||||
type: str
|
||||
CustomObjArgument:
|
||||
kind: struct
|
||||
fields:
|
||||
name:
|
||||
type: str
|
||||
class_fqn:
|
||||
type: str
|
||||
Device:
|
||||
kind: struct
|
||||
fields:
|
||||
type:
|
||||
type: str
|
||||
index:
|
||||
type: Optional[int]
|
||||
ExportedProgram:
|
||||
kind: struct
|
||||
fields:
|
||||
graph_module:
|
||||
type: GraphModule
|
||||
opset_version:
|
||||
type: Dict[str, int]
|
||||
range_constraints:
|
||||
type: Dict[str, RangeConstraint]
|
||||
schema_version:
|
||||
type: SchemaVersion
|
||||
dialect:
|
||||
type: str
|
||||
GradientToParameterSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: TensorArgument
|
||||
parameter_name:
|
||||
type: str
|
||||
GradientToUserInputSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: TensorArgument
|
||||
user_input_name:
|
||||
type: str
|
||||
Graph:
|
||||
kind: struct
|
||||
fields:
|
||||
inputs:
|
||||
type: List[Argument]
|
||||
outputs:
|
||||
type: List[Argument]
|
||||
nodes:
|
||||
type: List[Node]
|
||||
tensor_values:
|
||||
type: Dict[str, TensorMeta]
|
||||
sym_int_values:
|
||||
type: Dict[str, SymInt]
|
||||
sym_bool_values:
|
||||
type: Dict[str, SymBool]
|
||||
is_single_tensor_return:
|
||||
type: bool
|
||||
default: 'False'
|
||||
custom_obj_values:
|
||||
type: Dict[str, CustomObjArgument]
|
||||
default: '{}'
|
||||
GraphArgument:
|
||||
kind: struct
|
||||
fields:
|
||||
name:
|
||||
type: str
|
||||
graph:
|
||||
type: Graph
|
||||
GraphModule:
|
||||
kind: struct
|
||||
fields:
|
||||
graph:
|
||||
type: Graph
|
||||
signature:
|
||||
type: GraphSignature
|
||||
module_call_graph:
|
||||
type: List[ModuleCallEntry]
|
||||
GraphSignature:
|
||||
kind: struct
|
||||
fields:
|
||||
input_specs:
|
||||
type: List[InputSpec]
|
||||
output_specs:
|
||||
type: List[OutputSpec]
|
||||
InputSpec:
|
||||
kind: union
|
||||
fields:
|
||||
user_input:
|
||||
type: UserInputSpec
|
||||
parameter:
|
||||
type: InputToParameterSpec
|
||||
buffer:
|
||||
type: InputToBufferSpec
|
||||
tensor_constant:
|
||||
type: InputToTensorConstantSpec
|
||||
custom_obj:
|
||||
type: InputToCustomObjSpec
|
||||
InputToBufferSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: TensorArgument
|
||||
buffer_name:
|
||||
type: str
|
||||
InputToCustomObjSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: CustomObjArgument
|
||||
custom_obj_name:
|
||||
type: str
|
||||
InputToParameterSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: TensorArgument
|
||||
parameter_name:
|
||||
type: str
|
||||
InputToTensorConstantSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: TensorArgument
|
||||
tensor_constant_name:
|
||||
type: str
|
||||
Layout:
|
||||
kind: enum
|
||||
fields:
|
||||
Unknown: 0
|
||||
SparseCoo: 1
|
||||
SparseCsr: 2
|
||||
SparseCsc: 3
|
||||
SparseBsr: 4
|
||||
SparseBsc: 5
|
||||
_mkldnn: 6
|
||||
Strided: 7
|
||||
LossOutputSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: TensorArgument
|
||||
MemoryFormat:
|
||||
kind: enum
|
||||
fields:
|
||||
Unknown: 0
|
||||
ContiguousFormat: 1
|
||||
ChannelsLast: 2
|
||||
ChannelsLast3d: 3
|
||||
PreserveFormat: 4
|
||||
ModuleCallEntry:
|
||||
kind: struct
|
||||
fields:
|
||||
fqn:
|
||||
type: str
|
||||
signature:
|
||||
type: Optional[ModuleCallSignature]
|
||||
ModuleCallSignature:
|
||||
kind: struct
|
||||
fields:
|
||||
inputs:
|
||||
type: List[Argument]
|
||||
outputs:
|
||||
type: List[Argument]
|
||||
in_spec:
|
||||
type: str
|
||||
out_spec:
|
||||
type: str
|
||||
NamedArgument:
|
||||
kind: struct
|
||||
fields:
|
||||
name:
|
||||
type: str
|
||||
arg:
|
||||
type: Argument
|
||||
Node:
|
||||
kind: struct
|
||||
fields:
|
||||
target:
|
||||
type: str
|
||||
inputs:
|
||||
type: List[NamedArgument]
|
||||
outputs:
|
||||
type: List[Argument]
|
||||
metadata:
|
||||
type: Dict[str, str]
|
||||
OptionalTensorArgument:
|
||||
kind: union
|
||||
fields:
|
||||
as_tensor:
|
||||
type: str
|
||||
as_none:
|
||||
type: Tuple[()]
|
||||
OutputSpec:
|
||||
kind: union
|
||||
fields:
|
||||
user_output:
|
||||
type: UserOutputSpec
|
||||
loss_output:
|
||||
type: LossOutputSpec
|
||||
buffer_mutation:
|
||||
type: BufferMutationSpec
|
||||
gradient_to_parameter:
|
||||
type: GradientToParameterSpec
|
||||
gradient_to_user_input:
|
||||
type: GradientToUserInputSpec
|
||||
RangeConstraint:
|
||||
kind: struct
|
||||
fields:
|
||||
min_val:
|
||||
type: int
|
||||
max_val:
|
||||
type: int
|
||||
ScalarType:
|
||||
kind: enum
|
||||
fields:
|
||||
UNKNOWN: 0
|
||||
BYTE: 1
|
||||
CHAR: 2
|
||||
SHORT: 3
|
||||
INT: 4
|
||||
LONG: 5
|
||||
HALF: 6
|
||||
FLOAT: 7
|
||||
DOUBLE: 8
|
||||
COMPLEXHALF: 9
|
||||
COMPLEXFLOAT: 10
|
||||
COMPLEXDOUBLE: 11
|
||||
BOOL: 12
|
||||
BFLOAT16: 13
|
||||
SchemaVersion:
|
||||
kind: struct
|
||||
fields:
|
||||
major:
|
||||
type: int
|
||||
minor:
|
||||
type: int
|
||||
SymBool:
|
||||
kind: union
|
||||
fields:
|
||||
as_expr:
|
||||
type: SymExpr
|
||||
as_bool:
|
||||
type: bool
|
||||
SymBoolArgument:
|
||||
kind: union
|
||||
fields:
|
||||
as_name:
|
||||
type: str
|
||||
as_bool:
|
||||
type: bool
|
||||
SymExpr:
|
||||
kind: struct
|
||||
fields:
|
||||
expr_str:
|
||||
type: str
|
||||
hint:
|
||||
type: Optional[SymExprHint]
|
||||
SymExprHint:
|
||||
kind: union
|
||||
fields:
|
||||
as_int:
|
||||
type: int
|
||||
as_float:
|
||||
type: float
|
||||
as_bool:
|
||||
type: bool
|
||||
SymInt:
|
||||
kind: union
|
||||
fields:
|
||||
as_expr:
|
||||
type: SymExpr
|
||||
as_int:
|
||||
type: int
|
||||
SymIntArgument:
|
||||
kind: union
|
||||
fields:
|
||||
as_name:
|
||||
type: str
|
||||
as_int:
|
||||
type: int
|
||||
TensorArgument:
|
||||
kind: struct
|
||||
fields:
|
||||
name:
|
||||
type: str
|
||||
TensorMeta:
|
||||
kind: struct
|
||||
fields:
|
||||
dtype:
|
||||
type: ScalarType
|
||||
sizes:
|
||||
type: List[SymInt]
|
||||
requires_grad:
|
||||
type: bool
|
||||
device:
|
||||
type: Device
|
||||
strides:
|
||||
type: List[SymInt]
|
||||
storage_offset:
|
||||
type: int
|
||||
layout:
|
||||
type: Layout
|
||||
UserInputSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: Argument
|
||||
UserOutputSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: Argument
|
||||
SCHEMA_VERSION:
|
||||
- 3
|
||||
- 1
|
||||
TREESPEC_VERSION: 1
|
279
torch/_export/serde/schema_check.py
Normal file
279
torch/_export/serde/schema_check.py
Normal file
@ -0,0 +1,279 @@
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import re
|
||||
import typing
|
||||
from enum import IntEnum
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from torch._export.serde import schema
|
||||
from torch._export.serde.union import _Union
|
||||
|
||||
|
||||
class SchemaUpdateError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _check(x, msg):
|
||||
if not x:
|
||||
raise SchemaUpdateError(msg)
|
||||
|
||||
|
||||
def _staged_schema():
|
||||
ret: Dict[str, Any] = {}
|
||||
defs = {}
|
||||
|
||||
def _handle_aggregate(ty):
|
||||
def dump_type(t):
|
||||
if isinstance(t, type):
|
||||
return t.__name__
|
||||
elif isinstance(t, str):
|
||||
assert t in defs
|
||||
return t
|
||||
elif o := typing.get_origin(t):
|
||||
# Lemme know if there's a better way to do this.
|
||||
if o == list:
|
||||
head = "List"
|
||||
elif o == dict:
|
||||
head = "Dict"
|
||||
elif o == tuple:
|
||||
if typing.get_args(t) == ():
|
||||
return "Tuple[()]"
|
||||
head = "Tuple"
|
||||
elif o == Union:
|
||||
args = typing.get_args(t)
|
||||
assert len(args) == 2 and args[1] == type(None)
|
||||
return f"Optional[{dump_type(args[0])}]"
|
||||
else:
|
||||
raise AssertionError(f"Type {t} is not supported in export schema.")
|
||||
return (
|
||||
f"{head}[{', '.join([dump_type(x) for x in typing.get_args(t)])}]"
|
||||
)
|
||||
elif t == ():
|
||||
return "()"
|
||||
else:
|
||||
raise AssertionError(f"Type {t} is not supported in export schema.")
|
||||
|
||||
def dump_field(f):
|
||||
ret = {"type": dump_type(f.type)}
|
||||
|
||||
value = None
|
||||
if f.default is not dataclasses.MISSING:
|
||||
value = f.default
|
||||
elif f.default_factory is not dataclasses.MISSING:
|
||||
value = f.default_factory()
|
||||
|
||||
if value is not None:
|
||||
default = str(value)
|
||||
ret["default"] = default
|
||||
return ret
|
||||
|
||||
return {f.name: dump_field(f) for f in dataclasses.fields(ty)}
|
||||
|
||||
def _handle_int_enum(name, ty):
|
||||
ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}
|
||||
|
||||
def _handle_struct(name, ty):
|
||||
ret[name] = {"kind": "struct", "fields": _handle_aggregate(ty)}
|
||||
|
||||
def _handle_union(name, ty):
|
||||
ret[name] = {"kind": "union", "fields": _handle_aggregate(ty)}
|
||||
|
||||
for name in dir(schema):
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
value = getattr(schema, name)
|
||||
|
||||
if hasattr(value, "__module__") and value.__module__ != schema.__name__:
|
||||
continue
|
||||
|
||||
defs[name] = value
|
||||
|
||||
for name, value in defs.items():
|
||||
if isinstance(value, type):
|
||||
if issubclass(value, IntEnum):
|
||||
_handle_int_enum(name, value)
|
||||
elif dataclasses.is_dataclass(value):
|
||||
if issubclass(value, _Union):
|
||||
_handle_union(name, value)
|
||||
else:
|
||||
_handle_struct(name, value)
|
||||
else:
|
||||
raise AssertionError(f"Unknown schema type {name}: {value}")
|
||||
elif isinstance(value, (int, tuple)):
|
||||
assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION")
|
||||
else:
|
||||
raise AssertionError(f"Unknown variable {name}: {value}")
|
||||
|
||||
ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"])
|
||||
assert all(x > 0 for x in ret["SCHEMA_VERSION"])
|
||||
ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"]
|
||||
assert ret["TREESPEC_VERSION"] > 0
|
||||
return ret
|
||||
|
||||
|
||||
def _diff_schema(dst, src):
|
||||
additions = {key: src[key] for key in src.keys() - dst.keys()}
|
||||
subtractions = {key: dst[key] for key in dst.keys() - src.keys()}
|
||||
|
||||
common_keys = src.keys() & dst.keys()
|
||||
|
||||
versions = {"SCHEMA_VERSION", "TREESPEC_VERSION"}
|
||||
common_keys -= versions
|
||||
|
||||
for key in common_keys:
|
||||
src_kind = src[key]["kind"]
|
||||
src_fields = src[key]["fields"]
|
||||
dst_kind = dst[key]["kind"]
|
||||
dst_fields = dst[key]["fields"]
|
||||
_check(
|
||||
src_kind == dst_kind,
|
||||
f"Type {key} changed kind from {dst_kind} to {src_kind}",
|
||||
)
|
||||
assert isinstance(src_fields, dict) and isinstance(dst_fields, dict)
|
||||
added_fields = {
|
||||
key: src_fields[key] for key in src_fields.keys() - dst_fields.keys()
|
||||
}
|
||||
subtracted_fields = {
|
||||
key: dst_fields[key] for key in dst_fields.keys() - src_fields.keys()
|
||||
}
|
||||
common_fields = src_fields.keys() & dst_fields.keys()
|
||||
|
||||
for field in common_fields:
|
||||
src_field = src_fields[field]
|
||||
dst_field = dst_fields[field]
|
||||
if src_kind == "struct":
|
||||
_check(
|
||||
src_field["type"] == dst_field["type"],
|
||||
f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}",
|
||||
)
|
||||
if "default" in src_field and "default" not in dst_field:
|
||||
added_fields[field] = {}
|
||||
added_fields[field]["default"] = src_field["default"]
|
||||
if "default" not in src_field and "default" in dst_field:
|
||||
subtracted_fields[field] = {}
|
||||
subtracted_fields[field]["default"] = dst_field["default"]
|
||||
elif src_kind == "enum":
|
||||
_check(
|
||||
src_field == dst_field,
|
||||
f"Value of the enum field {key}.{field} changed from {dst_field} to {src_field}",
|
||||
)
|
||||
elif src_kind == "union":
|
||||
_check(
|
||||
src_field["type"] == dst_field["type"],
|
||||
f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}",
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unknown kind {src_kind}: {key}")
|
||||
if len(added_fields) > 0:
|
||||
assert key not in additions
|
||||
additions[key] = {}
|
||||
additions[key]["fields"] = added_fields
|
||||
if len(subtracted_fields) > 0:
|
||||
assert key not in subtractions
|
||||
subtractions[key] = {}
|
||||
subtractions[key]["fields"] = subtracted_fields
|
||||
|
||||
return additions, subtractions
|
||||
|
||||
|
||||
def _hash_schema(s):
|
||||
return hashlib.sha256(repr(s).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Commit:
|
||||
result: Dict[str, Any]
|
||||
checksum_result: str
|
||||
path: str
|
||||
additions: Dict[str, Any]
|
||||
subtractions: Dict[str, Any]
|
||||
base: Dict[str, Any]
|
||||
checksum_base: Optional[str]
|
||||
|
||||
|
||||
def update_schema():
|
||||
import importlib.resources
|
||||
|
||||
if importlib.resources.is_resource(__package__, "schema.yaml"):
|
||||
content = importlib.resources.read_text(__package__, "schema.yaml")
|
||||
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
|
||||
_check(match is not None, "checksum not found in schema.yaml")
|
||||
assert match is not None
|
||||
checksum_base = match.group(1)
|
||||
from yaml import load, Loader
|
||||
|
||||
dst = load(content, Loader=Loader)
|
||||
assert isinstance(dst, dict)
|
||||
else:
|
||||
checksum_base = None
|
||||
dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
|
||||
|
||||
src = _staged_schema()
|
||||
additions, subtractions = _diff_schema(dst, src)
|
||||
return _Commit(
|
||||
result=src,
|
||||
checksum_result=_hash_schema(src),
|
||||
path=__package__.replace(".", "/") + "/schema.yaml",
|
||||
additions=additions,
|
||||
subtractions=subtractions,
|
||||
base=dst,
|
||||
checksum_base=checksum_base,
|
||||
)
|
||||
|
||||
|
||||
def check(commit: _Commit, force_unsafe: bool = False):
|
||||
next_version = None
|
||||
reason = ""
|
||||
# Step 1: Detect major schema updates.
|
||||
if len(commit.additions) > 0:
|
||||
for k, v in commit.additions.items():
|
||||
if k not in commit.base:
|
||||
continue
|
||||
kind = commit.result[k]["kind"]
|
||||
fields = v["fields"]
|
||||
for f, d in fields.items():
|
||||
if "default" not in d and kind == "struct":
|
||||
reason += (
|
||||
f"Field {k}.{f} is added to schema.py without a default value as an incomparible change "
|
||||
+ "which requires major version bump.\n"
|
||||
)
|
||||
next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]
|
||||
|
||||
if len(commit.subtractions) > 0:
|
||||
for k, v in commit.subtractions.items():
|
||||
if k not in commit.result:
|
||||
continue
|
||||
for f in v["fields"]:
|
||||
reason = f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n"
|
||||
next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]
|
||||
|
||||
if force_unsafe:
|
||||
reason += "--force-unsafe is used."
|
||||
next_version = commit.result["SCHEMA_VERSION"]
|
||||
else:
|
||||
# Step 2: Detect minor schema updates.
|
||||
if next_version is None and len(commit.additions) > 0:
|
||||
for k, v in commit.additions.items():
|
||||
for f in v["fields"]:
|
||||
reason += (
|
||||
f"Field {k}.{f} is added to schema.py as an compatible change "
|
||||
+ "which still requires minor version bump.\n"
|
||||
)
|
||||
next_version = [
|
||||
commit.base["SCHEMA_VERSION"][0],
|
||||
commit.base["SCHEMA_VERSION"][1] + 1,
|
||||
]
|
||||
if next_version is None and len(commit.subtractions) > 0:
|
||||
for k, v in commit.subtractions.items():
|
||||
for f in v["fields"]:
|
||||
reason += (
|
||||
f"Field {k}.{f} is removed from schema.py as an compatible change "
|
||||
+ "which still requires minor version bump.\n"
|
||||
)
|
||||
next_version = [
|
||||
commit.base["SCHEMA_VERSION"][0],
|
||||
commit.base["SCHEMA_VERSION"][1] + 1,
|
||||
]
|
||||
|
||||
return next_version, reason
|
Reference in New Issue
Block a user