[export] Generate compatible thrift schema out of schema.py (#141611)

Summary: To keep schema.py and schema.thrift in sync, we reuse the integer field ids from Thrift and attach them to schema.py fields via Python's Annotated type, associating each field in one schema with its counterpart in the other. Later we will use this association to build a single source of truth for both schemas.
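As a minimal sketch of that association (the Device shape below mirrors schema.py; the printing loop is illustrative, not part of this diff), the Thrift field id rides along as typing.Annotated metadata and leaves the dataclass's runtime behavior untouched:

    from dataclasses import dataclass, fields
    from typing import Annotated, Optional

    @dataclass
    class Device:
        type: Annotated[str, 10]                    # 10 is the Thrift field id
        index: Annotated[Optional[int], 20] = None  # 20 matches "20: optional i64 index;"

    # The id is recoverable from the annotation metadata; schema_check.py
    # reads it the same way, via f.type.__metadata__[0].
    for f in fields(Device):
        print(f.name, f.type.__metadata__[0])       # type -> 10, index -> 20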

Test Plan: CI

Differential Revision: D66253157

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141611
Approved by: https://github.com/yiming0416
Author: Zhengxu Chen
Date: 2024-11-29 20:09:49 +00:00
Committed by: PyTorch MergeBot
Parent: 7dd9b5fc43
Commit: a8a570512b
6 changed files with 575 additions and 163 deletions


@@ -72,6 +72,10 @@ if __name__ == "__main__":
yaml_content = yaml_header + "\n" + yaml_payload
thrift_schema = "// " + first_line
thrift_schema += "\n// " + checksum
thrift_schema += "\n" + commit.thrift_schema
if args.dry_run:
print(yaml_content)
print("\nWill write the above schema to" + args.prefix + commit.yaml_path)
@@ -80,3 +84,5 @@ if __name__ == "__main__":
f.write(yaml_content)
with open(args.prefix + commit.cpp_header_path, "w") as f:
f.write(cpp_header)
with open(args.prefix + commit.thrift_schema_path, "w") as f:
f.write(thrift_schema)


@@ -113,6 +113,8 @@ Example(s):
checksum_base="",
cpp_header="",
cpp_header_path="",
thrift_schema="",
thrift_schema_path="",
)
next_version, _ = check(commit)
self.assertEqual(next_version, [4, 1])
@@ -147,6 +149,8 @@ Example(s):
checksum_base="",
cpp_header="",
cpp_header_path="",
thrift_schema="",
thrift_schema_path="",
)
next_version, _ = check(commit)
self.assertEqual(next_version, [4, 1])
@@ -184,6 +188,8 @@ Example(s):
checksum_base="",
cpp_header="",
cpp_header_path="",
thrift_schema="",
thrift_schema_path="",
)
next_version, _ = check(commit)
self.assertEqual(next_version, [3, 3])
@@ -244,6 +250,8 @@ Example(s):
checksum_base="",
cpp_header="",
cpp_header_path="",
thrift_schema="",
thrift_schema_path="",
)
next_version, _ = check(commit)
self.assertEqual(next_version, [3, 3])
@@ -274,6 +282,8 @@ Example(s):
checksum_base="",
cpp_header="",
cpp_header_path="",
thrift_schema="",
thrift_schema_path="",
)
next_version, _ = check(commit)
self.assertEqual(next_version, [3, 3])
@@ -311,6 +321,8 @@ Example(s):
checksum_base="",
cpp_header="",
cpp_header_path="",
thrift_schema="",
thrift_schema_path="",
)
next_version, _ = check(commit)
self.assertEqual(next_version, [3, 3])
@@ -345,6 +357,8 @@ Example(s):
checksum_base="",
cpp_header="",
cpp_header_path="",
thrift_schema="",
thrift_schema_path="",
)
next_version, _ = check(commit)
self.assertEqual(next_version, [4, 1])


@@ -3,7 +3,7 @@
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Dict, List, Optional, Tuple
from typing import Annotated, Dict, List, Optional, Tuple
from torch._export.serde.union import _Union
@@ -53,15 +53,15 @@ class MemoryFormat(IntEnum):
@dataclass
class Device:
type: str
index: Optional[int] = None
type: Annotated[str, 10]
index: Annotated[Optional[int], 20] = None
@dataclass(repr=False)
class SymExprHint(_Union):
as_int: int
as_float: float
as_bool: bool
as_int: Annotated[int, 10]
as_float: Annotated[float, 20]
as_bool: Annotated[bool, 30]
# This is for storing the symbolic expressions behind symints/symfloats/symbools
@@ -70,31 +70,31 @@ class SymExprHint(_Union):
# if we also have the hint that s0 and s1 are both 2.
@dataclass
class SymExpr:
expr_str: str
hint: Optional[SymExprHint] = None
expr_str: Annotated[str, 10]
hint: Annotated[Optional[SymExprHint], 20] = None
@dataclass(repr=False)
class SymInt(_Union):
as_expr: SymExpr
as_int: int
as_expr: Annotated[SymExpr, 10]
as_int: Annotated[int, 20]
@dataclass(repr=False)
class SymBool(_Union):
as_expr: SymExpr
as_bool: bool
as_expr: Annotated[SymExpr, 10]
as_bool: Annotated[bool, 20]
@dataclass
class TensorMeta:
dtype: ScalarType
sizes: List[SymInt]
requires_grad: bool
device: Device
strides: List[SymInt]
storage_offset: SymInt
layout: Layout
dtype: Annotated[ScalarType, 10]
sizes: Annotated[List[SymInt], 20]
requires_grad: Annotated[bool, 30]
device: Annotated[Device, 40]
strides: Annotated[List[SymInt], 50]
storage_offset: Annotated[SymInt, 60]
layout: Annotated[Layout, 70]
# In most cases we will use the "as_name" field to store arguments which are
@@ -105,8 +105,8 @@ class TensorMeta:
# to the "as_int" field.
@dataclass(repr=False)
class SymIntArgument(_Union):
as_name: str
as_int: int
as_name: Annotated[str, 10]
as_int: Annotated[int, 20]
# In most cases we will use the "as_name" field to store arguments which are
@@ -117,18 +117,18 @@ class SymIntArgument(_Union):
# to the "as_bool" field.
@dataclass(repr=False)
class SymBoolArgument(_Union):
as_name: str
as_bool: bool
as_name: Annotated[str, 10]
as_bool: Annotated[bool, 20]
@dataclass
class TensorArgument:
name: str
name: Annotated[str, 10]
@dataclass
class TokenArgument:
name: str
name: Annotated[str, 10]
# This is used for storing the contents of a list which contains optional tensors
@@ -137,252 +137,252 @@ class TokenArgument:
# "as_tensor" field, and None values serialized to the "as_none" field.
@dataclass(repr=False)
class OptionalTensorArgument(_Union):
as_tensor: TensorArgument
as_none: Tuple[()]
as_tensor: Annotated[TensorArgument, 20]
as_none: Annotated[Tuple[()], 10]
@dataclass
class GraphArgument:
name: str
graph: 'Graph'
name: Annotated[str, 10]
graph: Annotated['Graph', 20]
@dataclass
class CustomObjArgument:
name: str
class_fqn: str
name: Annotated[str, 10]
class_fqn: Annotated[str, 20]
# This is actually a union type
@dataclass(repr=False)
class Argument(_Union):
as_none: Tuple[()]
as_tensor: TensorArgument
as_tensors: List[TensorArgument]
as_int: int
as_ints: List[int]
as_float: float
as_floats: List[float]
as_string: str
as_strings: List[str]
as_sym_int: SymIntArgument
as_sym_ints: List[SymIntArgument]
as_scalar_type: ScalarType
as_memory_format: MemoryFormat
as_layout: Layout
as_device: Device
as_bool: bool
as_bools: List[bool]
as_sym_bool: SymBoolArgument
as_sym_bools: List[SymBoolArgument]
as_graph: GraphArgument
as_optional_tensors: List[OptionalTensorArgument]
as_custom_obj: CustomObjArgument
as_operator: str
as_none: Annotated[Tuple[()], 10]
as_tensor: Annotated[TensorArgument, 20]
as_tensors: Annotated[List[TensorArgument], 30]
as_int: Annotated[int, 50]
as_ints: Annotated[List[int], 70]
as_float: Annotated[float, 80]
as_floats: Annotated[List[float], 90]
as_string: Annotated[str, 100]
as_strings: Annotated[List[str], 101]
as_sym_int: Annotated[SymIntArgument, 110]
as_sym_ints: Annotated[List[SymIntArgument], 120]
as_scalar_type: Annotated[ScalarType, 130]
as_memory_format: Annotated[MemoryFormat, 140]
as_layout: Annotated[Layout, 150]
as_device: Annotated[Device, 160]
as_bool: Annotated[bool, 170]
as_bools: Annotated[List[bool], 180]
as_sym_bool: Annotated[SymBoolArgument, 182]
as_sym_bools: Annotated[List[SymBoolArgument], 184]
as_graph: Annotated[GraphArgument, 200]
as_optional_tensors: Annotated[List[OptionalTensorArgument], 190]
as_custom_obj: Annotated[CustomObjArgument, 210]
as_operator: Annotated[str, 220]
@dataclass
class NamedArgument:
# Argument name from the operator schema
name: str
arg: Argument
name: Annotated[str, 10]
arg: Annotated[Argument, 20]
@dataclass
class Node:
target: str
inputs: List[NamedArgument]
outputs: List[Argument]
metadata: Dict[str, str]
target: Annotated[str, 10]
inputs: Annotated[List[NamedArgument], 20]
outputs: Annotated[List[Argument], 30]
metadata: Annotated[Dict[str, str], 40]
@dataclass
class Graph:
inputs: List[Argument]
outputs: List[Argument]
nodes: List[Node]
tensor_values: Dict[str, TensorMeta]
sym_int_values: Dict[str, SymInt]
sym_bool_values: Dict[str, SymBool]
inputs: Annotated[List[Argument], 10]
outputs: Annotated[List[Argument], 20]
nodes: Annotated[List[Node], 30]
tensor_values: Annotated[Dict[str, TensorMeta], 40]
sym_int_values: Annotated[Dict[str, SymInt], 50]
sym_bool_values: Annotated[Dict[str, SymBool], 60]
# This is for deserializing the submodule graphs from higher order ops
# (ex. cond, map) where single tensor returns will just return a single
# tensor, rather than following export schema and returning a singleton
# list.
is_single_tensor_return: bool = False
custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
is_single_tensor_return: Annotated[bool, 70] = False
custom_obj_values: Annotated[Dict[str, CustomObjArgument], 80] = field(default_factory=dict)
@dataclass
class UserInputSpec:
# Actually, only tensors and SymInts are allowed here
arg: Argument
arg: Annotated[Argument, 10]
@dataclass(repr=False)
class ConstantValue(_Union):
as_none: Tuple[()]
as_int: int
as_float: float
as_string: str
as_bool: bool
as_none: Annotated[Tuple[()], 10]
as_int: Annotated[int, 20]
as_float: Annotated[float, 30]
as_string: Annotated[str, 40]
as_bool: Annotated[bool, 50]
@dataclass
class InputToConstantInputSpec:
name: str
value: ConstantValue
name: Annotated[str, 10]
value: Annotated[ConstantValue, 20]
@dataclass
class InputToParameterSpec:
arg: TensorArgument
parameter_name: str
arg: Annotated[TensorArgument, 10]
parameter_name: Annotated[str, 20]
@dataclass
class InputToBufferSpec:
arg: TensorArgument
buffer_name: str
persistent: bool
arg: Annotated[TensorArgument, 10]
buffer_name: Annotated[str, 20]
persistent: Annotated[bool, 30]
@dataclass
class InputToTensorConstantSpec:
arg: TensorArgument
tensor_constant_name: str
arg: Annotated[TensorArgument, 10]
tensor_constant_name: Annotated[str, 20]
@dataclass
class InputToCustomObjSpec:
arg: CustomObjArgument
custom_obj_name: str
arg: Annotated[CustomObjArgument, 10]
custom_obj_name: Annotated[str, 20]
@dataclass
class InputTokenSpec:
arg: TokenArgument
arg: Annotated[TokenArgument, 10]
@dataclass(repr=False)
class InputSpec(_Union):
user_input: UserInputSpec
parameter: InputToParameterSpec
buffer: InputToBufferSpec
tensor_constant: InputToTensorConstantSpec
custom_obj: InputToCustomObjSpec
token: InputTokenSpec
constant_input: InputToConstantInputSpec
user_input: Annotated[UserInputSpec, 10]
parameter: Annotated[InputToParameterSpec, 20]
buffer: Annotated[InputToBufferSpec, 30]
tensor_constant: Annotated[InputToTensorConstantSpec, 40]
custom_obj: Annotated[InputToCustomObjSpec, 50]
token: Annotated[InputTokenSpec, 70]
constant_input: Annotated[InputToConstantInputSpec, 60]
@dataclass
class UserOutputSpec:
arg: Argument
arg: Annotated[Argument, 10]
@dataclass
class LossOutputSpec:
arg: TensorArgument
arg: Annotated[TensorArgument, 10]
@dataclass
class BufferMutationSpec:
arg: TensorArgument
buffer_name: str
arg: Annotated[TensorArgument, 10]
buffer_name: Annotated[str, 20]
@dataclass
class GradientToParameterSpec:
arg: TensorArgument
parameter_name: str
arg: Annotated[TensorArgument, 10]
parameter_name: Annotated[str, 20]
@dataclass
class GradientToUserInputSpec:
arg: TensorArgument
user_input_name: str
arg: Annotated[TensorArgument, 10]
user_input_name: Annotated[str, 20]
@dataclass
class UserInputMutationSpec:
arg: TensorArgument
user_input_name: str
arg: Annotated[TensorArgument, 10]
user_input_name: Annotated[str, 20]
@dataclass
class OutputTokenSpec:
arg: TokenArgument
arg: Annotated[TokenArgument, 10]
@dataclass(repr=False)
class OutputSpec(_Union):
user_output: UserOutputSpec
loss_output: LossOutputSpec
buffer_mutation: BufferMutationSpec
gradient_to_parameter: GradientToParameterSpec
gradient_to_user_input: GradientToUserInputSpec
user_input_mutation: UserInputMutationSpec
token: OutputTokenSpec
user_output: Annotated[UserOutputSpec, 10]
loss_output: Annotated[LossOutputSpec, 20]
buffer_mutation: Annotated[BufferMutationSpec, 30]
gradient_to_parameter: Annotated[GradientToParameterSpec, 40]
gradient_to_user_input: Annotated[GradientToUserInputSpec, 50]
user_input_mutation: Annotated[UserInputMutationSpec, 60]
token: Annotated[OutputTokenSpec, 70]
@dataclass
class GraphSignature:
input_specs: List[InputSpec]
output_specs: List[OutputSpec]
input_specs: Annotated[List[InputSpec], 10]
output_specs: Annotated[List[OutputSpec], 20]
@dataclass
class RangeConstraint:
min_val: Optional[int]
max_val: Optional[int]
min_val: Annotated[Optional[int], 10]
max_val: Annotated[Optional[int], 20]
@dataclass
class ModuleCallSignature:
inputs: List[Argument]
outputs: List[Argument]
inputs: Annotated[List[Argument], 10]
outputs: Annotated[List[Argument], 20]
# These are serialized by calling pytree.treespec_loads
# And deserialized by calling pytree.treespec_dumps
in_spec: str
out_spec: str
in_spec: Annotated[str, 30]
out_spec: Annotated[str, 40]
# This field is used to prettify the graph placeholders
# after we ser/der and retrace
forward_arg_names: Optional[List[str]] = None
forward_arg_names: Annotated[Optional[List[str]], 50] = None
@dataclass
class ModuleCallEntry:
fqn: str
signature: Optional[ModuleCallSignature] = None
fqn: Annotated[str, 10]
signature: Annotated[Optional[ModuleCallSignature], 30] = None
@dataclass
class GraphModule:
graph: Graph
signature: GraphSignature
graph: Annotated[Graph, 10]
signature: Annotated[GraphSignature, 50]
# This is used for unflattening, by tracking the calling structure of all of
# the modules in order to unflatten the modules back to the eager calling
# conventions.
module_call_graph: List[ModuleCallEntry]
metadata: Dict[str, str] = field(default_factory=dict)
module_call_graph: Annotated[List[ModuleCallEntry], 60]
metadata: Annotated[Dict[str, str], 40] = field(default_factory=dict)
# Invariant: Every time a change is made to the schema, one of the versions
# should be updated.
@dataclass
class SchemaVersion:
major: int # Major version number is bumped every time a breaking change is made.
minor: int # Minor version number is bumped when a compatible change is made.
major: Annotated[int, 10] # Major version number is bumped every time a breaking change is made.
minor: Annotated[int, 20] # Minor version number is bumped when a compatible change is made.
@dataclass
class ExportedProgram:
graph_module: GraphModule
graph_module: Annotated[GraphModule, 10]
# Key is the opset namespace (ex. aten), and value is the version number
opset_version: Dict[str, int]
range_constraints: Dict[str, RangeConstraint]
schema_version: SchemaVersion
verifiers: List[str] = field(default_factory=list)
torch_version: str = "<=2.4"
opset_version: Annotated[Dict[str, int], 20]
range_constraints: Annotated[Dict[str, RangeConstraint], 30]
schema_version: Annotated[SchemaVersion, 60]
verifiers: Annotated[List[str], 70] = field(default_factory=list)
torch_version: Annotated[str, 80] = "<=2.4"
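One detail worth noting about the schema.py changes above: Annotated is transparent to ordinary typing introspection, so consumers that resolve field types keep seeing the plain types unless they opt in. A small sketch:

    import typing
    from torch._export.serde.schema import Device

    print(typing.get_type_hints(Device)["type"])
    # <class 'str'> -- the field-id metadata is stripped by default

    print(typing.get_type_hints(Device, include_extras=True)["type"])
    # typing.Annotated[str, 10]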


@@ -0,0 +1,301 @@
// @generated by update_schema.py
// checksum<<19d86105f895a10d5eedbc6e13d4d96cf5d9182c0367d6825ef2438e124cc536>>
namespace py3 torch._export.schema
namespace cpp2 torch._export.schema
enum Layout {
Unknown = 0,
SparseCoo = 1,
SparseCsr = 2,
SparseCsc = 3,
SparseBsr = 4,
SparseBsc = 5,
_mkldnn = 6,
Strided = 7,
}
enum MemoryFormat {
Unknown = 0,
ContiguousFormat = 1,
ChannelsLast = 2,
ChannelsLast3d = 3,
PreserveFormat = 4,
}
enum ScalarType {
UNKNOWN = 0,
BYTE = 1,
CHAR = 2,
SHORT = 3,
INT = 4,
LONG = 5,
HALF = 6,
FLOAT = 7,
DOUBLE = 8,
COMPLEXHALF = 9,
COMPLEXFLOAT = 10,
COMPLEXDOUBLE = 11,
BOOL = 12,
BFLOAT16 = 13,
UINT16 = 28,
}
struct Device {
10: string type;
20: optional i64 index;
}
union SymExprHint {
10: i64 as_int;
20: double as_float;
30: bool as_bool;
}
struct SymExpr {
10: string expr_str;
20: optional SymExprHint hint;
}
union SymInt {
10: SymExpr as_expr;
20: i64 as_int;
}
union SymBool {
10: SymExpr as_expr;
20: bool as_bool;
}
struct TensorMeta {
10: ScalarType dtype;
20: list<SymInt> sizes;
30: bool requires_grad;
40: Device device;
50: list<SymInt> strides;
60: SymInt storage_offset;
70: Layout layout;
}
union SymIntArgument {
10: string as_name;
20: i64 as_int;
}
union SymBoolArgument {
10: string as_name;
20: bool as_bool;
}
struct TensorArgument {
10: string name;
}
struct TokenArgument {
10: string name;
}
union OptionalTensorArgument {
20: TensorArgument as_tensor;
10: bool as_none;
}
struct GraphArgument {
10: string name;
20: Graph graph;
}
struct CustomObjArgument {
10: string name;
20: string class_fqn;
}
union Argument {
10: bool as_none;
20: TensorArgument as_tensor;
30: list<TensorArgument> as_tensors;
50: i64 as_int;
70: list<i64> as_ints;
80: double as_float;
90: list<double> as_floats;
100: string as_string;
101: list<string> as_strings;
110: SymIntArgument as_sym_int;
120: list<SymIntArgument> as_sym_ints;
130: ScalarType as_scalar_type;
140: MemoryFormat as_memory_format;
150: Layout as_layout;
160: Device as_device;
170: bool as_bool;
180: list<bool> as_bools;
182: SymBoolArgument as_sym_bool;
184: list<SymBoolArgument> as_sym_bools;
200: GraphArgument as_graph;
190: list<OptionalTensorArgument> as_optional_tensors;
210: CustomObjArgument as_custom_obj;
220: string as_operator;
}
struct NamedArgument {
10: string name;
20: Argument arg;
}
struct Node {
10: string target;
20: list<NamedArgument> inputs;
30: list<Argument> outputs;
40: map<string, string> metadata;
}
struct Graph {
10: list<Argument> inputs;
20: list<Argument> outputs;
30: list<Node> nodes;
40: map<string, TensorMeta> tensor_values;
50: map<string, SymInt> sym_int_values;
60: map<string, SymBool> sym_bool_values;
70: bool is_single_tensor_return;
80: map<string, CustomObjArgument> custom_obj_values;
}
struct UserInputSpec {
10: Argument arg;
}
union ConstantValue {
10: bool as_none;
20: i64 as_int;
30: double as_float;
40: string as_string;
50: bool as_bool;
}
struct InputToConstantInputSpec {
10: string name;
20: ConstantValue value;
}
struct InputToParameterSpec {
10: TensorArgument arg;
20: string parameter_name;
}
struct InputToBufferSpec {
10: TensorArgument arg;
20: string buffer_name;
30: bool persistent;
}
struct InputToTensorConstantSpec {
10: TensorArgument arg;
20: string tensor_constant_name;
}
struct InputToCustomObjSpec {
10: CustomObjArgument arg;
20: string custom_obj_name;
}
struct InputTokenSpec {
10: TokenArgument arg;
}
union InputSpec {
10: UserInputSpec user_input;
20: InputToParameterSpec parameter;
30: InputToBufferSpec buffer;
40: InputToTensorConstantSpec tensor_constant;
50: InputToCustomObjSpec custom_obj;
70: InputTokenSpec token;
60: InputToConstantInputSpec constant_input;
}
struct UserOutputSpec {
10: Argument arg;
}
struct LossOutputSpec {
10: TensorArgument arg;
}
struct BufferMutationSpec {
10: TensorArgument arg;
20: string buffer_name;
}
struct GradientToParameterSpec {
10: TensorArgument arg;
20: string parameter_name;
}
struct GradientToUserInputSpec {
10: TensorArgument arg;
20: string user_input_name;
}
struct UserInputMutationSpec {
10: TensorArgument arg;
20: string user_input_name;
}
struct OutputTokenSpec {
10: TokenArgument arg;
}
union OutputSpec {
10: UserOutputSpec user_output;
20: LossOutputSpec loss_output;
30: BufferMutationSpec buffer_mutation;
40: GradientToParameterSpec gradient_to_parameter;
50: GradientToUserInputSpec gradient_to_user_input;
60: UserInputMutationSpec user_input_mutation;
70: OutputTokenSpec token;
}
struct GraphSignature {
10: list<InputSpec> input_specs;
20: list<OutputSpec> output_specs;
}
struct RangeConstraint {
10: optional i64 min_val;
20: optional i64 max_val;
}
struct ModuleCallSignature {
10: list<Argument> inputs;
20: list<Argument> outputs;
30: string in_spec;
40: string out_spec;
50: optional list<string> forward_arg_names;
}
struct ModuleCallEntry {
10: string fqn;
30: optional ModuleCallSignature signature;
}
struct GraphModule {
10: Graph graph;
50: GraphSignature signature;
60: list<ModuleCallEntry> module_call_graph;
40: map<string, string> metadata;
}
struct SchemaVersion {
10: i64 major;
20: i64 minor;
}
struct ExportedProgram {
10: GraphModule graph_module;
20: map<string, i64> opset_version;
30: map<string, RangeConstraint> range_constraints;
60: SchemaVersion schema_version;
70: list<string> verifiers;
80: string torch_version;
}
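The checksum<<...>> header above is 64 hex digits, consistent with a SHA-256 digest of the generated content; the exact hash input is not visible in this diff, so the following is only a hedged sketch:

    import hashlib

    def content_checksum(payload: str) -> str:
        # Assumption: the header embeds sha256 over the schema body.
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    header = f"// checksum<<{content_checksum('...schema text...')}>>"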


@@ -5,7 +5,7 @@ import inspect
import re
import typing
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Annotated, Any, Dict, ForwardRef, List, Optional, Tuple, Union
from torch._export.serde import schema
from torch._export.serde.union import _Union
@@ -27,50 +27,91 @@ def _staged_schema():
cpp_class_defs: Dict[str, str] = {}
cpp_type_decls: List[str] = []
cpp_json_defs: List[str] = []
thrift_enum_defs: List[str] = []
thrift_type_defs: Dict[str, str] = {}
def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any]]:
def dump_type(t) -> Tuple[str, str]:
TYPE_MAP = {
def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
def dump_type(t) -> Tuple[str, str, str]:
CPP_TYPE_MAP = {
str: "std::string",
int: "int64_t",
float: "double",
bool: "bool",
}
THRIFT_TYPE_MAP = {
str: "string",
int: "i64",
float: "double",
bool: "bool",
}
if isinstance(t, type):
if t.__name__ in cpp_enum_defs:
return t.__name__, "int64_t"
return t.__name__, "int64_t", t.__name__
else:
return t.__name__, TYPE_MAP.get(t, t.__name__)
return (
t.__name__,
CPP_TYPE_MAP.get(t, t.__name__),
THRIFT_TYPE_MAP.get(t, t.__name__),
)
elif isinstance(t, str):
assert t in defs
assert t not in cpp_enum_defs
assert "[" not in t
return t, f"ForwardRef<{t}>"
return t, f"ForwardRef<{t}>", t
elif isinstance(t, ForwardRef):
return (
t.__forward_arg__,
f"ForwardRef<{t.__forward_arg__}>",
t.__forward_arg__,
)
elif o := typing.get_origin(t):
# Lemme know if there's a better way to do this.
if o == list:
yaml_head, cpp_head = "List", "std::vector"
yaml_head, cpp_head, thrift_head, thrift_tail = (
"List",
"std::vector",
"list<",
">",
)
elif o == dict:
yaml_head, cpp_head = "Dict", "std::unordered_map"
yaml_head, cpp_head, thrift_head, thrift_tail = (
"Dict",
"std::unordered_map",
"map<",
">",
)
elif o == tuple:
if typing.get_args(t) == ():
return "Tuple[()]", "std::tuple<>"
yaml_head, cpp_head = "Tuple", "std::tuple"
return "Tuple[()]", "std::tuple<>", "bool"
yaml_head, cpp_head, thrift_head, thrift_tail = (
"Tuple",
"std::tuple",
"bool",
"",
)
elif o == Union:
args = typing.get_args(t)
assert len(args) == 2 and args[1] == type(None)
yaml_type, cpp_type = dump_type(args[0])
return f"Optional[{yaml_type}]", f"std::optional<{cpp_type}>"
yaml_type, cpp_type, thrift_type = dump_type(args[0])
return (
f"Optional[{yaml_type}]",
f"std::optional<{cpp_type}>",
f"optional {thrift_type}",
)
elif o == Annotated:
return dump_type(t.__origin__)
else:
raise AssertionError(f"Type {t} is not supported in export schema.")
yaml_arg_types, cpp_arg_types = zip(
yaml_arg_types, cpp_arg_types, thrift_arg_types = zip(
*[dump_type(x) for x in typing.get_args(t)]
)
return (f"{yaml_head}[{', '.join(yaml_arg_types)}]"), (
f"{cpp_head}<{', '.join(cpp_arg_types)}>"
return (
(f"{yaml_head}[{', '.join(yaml_arg_types)}]"),
(f"{cpp_head}<{', '.join(cpp_arg_types)}>"),
f"{thrift_head}{', '.join(thrift_arg_types)}{thrift_tail}",
)
elif t == ():
return "()", ""
return "()", "", ""
else:
raise AssertionError(f"Type {t} is not supported in export schema.")
@@ -94,11 +135,17 @@ def _staged_schema():
f"Default value {v} is not supported yet in export schema."
)
def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str]]:
t, cpp = dump_type(f.type)
def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str], str, int]:
t, cpp_type, thrift_type = dump_type(f.type)
ret = {"type": t}
cpp_type = cpp
cpp_default: Optional[str] = None
assert (
typing.get_origin(f.type) == Annotated
), f"Field {f.name} must be annotated with an integer id."
thrift_id = f.type.__metadata__[0]
assert (
type(thrift_id) is int
), f"Field {f.name} must be annotated with an integer id."
value = dataclasses.MISSING
if f.default is not dataclasses.MISSING:
@@ -116,15 +163,23 @@ def _staged_schema():
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
)
return ret, cpp_type, cpp_default
return ret, cpp_type, cpp_default, thrift_type, thrift_id
yaml_ret = {}
cpp_ret = {}
thrift_ret = {}
thrift_ids = set()
for f in dataclasses.fields(ty):
yaml_res, cpp_type, cpp_default = dump_field(f)
yaml_res, cpp_type, cpp_default, thrift_type, thrift_id = dump_field(f)
yaml_ret[f.name] = yaml_res
cpp_ret[f.name] = {"cpp_type": cpp_type, "cpp_default": cpp_default}
return yaml_ret, cpp_ret
thrift_ret[f.name] = {"thrift_type": thrift_type, "thrift_id": thrift_id}
if thrift_id in thrift_ids:
raise AssertionError(
f"Duplicate thrift id {thrift_id} for field {f.name} in {ty.__name__}."
)
thrift_ids.add(thrift_id)
return yaml_ret, cpp_ret, thrift_ret
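The duplicate-id assertion added above enforces the key Thrift invariant: field ids must be unique within each struct or union. The same check as a compact, self-contained helper (the helper name is hypothetical):

    import dataclasses

    def assert_unique_thrift_ids(ty) -> None:
        seen: dict = {}
        for f in dataclasses.fields(ty):
            thrift_id = f.type.__metadata__[0]  # the Annotated field id
            if thrift_id in seen:
                raise AssertionError(
                    f"Duplicate thrift id {thrift_id} for field {f.name} in "
                    f"{ty.__name__} (already used by {seen[thrift_id]})."
                )
            seen[thrift_id] = f.name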
def _handle_int_enum(name, ty):
yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}
@@ -135,9 +190,16 @@ enum class {name} {{
{chr(10).join([f" {x.name} = {x.value}," for x in ty])}
}};
"""
thrift_enum_defs.append(
f"""
enum {name} {{
{chr(10).join([f" {x.name} = {x.value}," for x in ty])}
}}
"""
)
def _handle_struct(name, ty):
fields, cpp_fields = _handle_aggregate(ty)
fields, cpp_fields, thrift_fields = _handle_aggregate(ty)
yaml_ret[name] = {"kind": "struct", "fields": fields}
field_decls = "\n".join(
f" {f['cpp_type']} {name}{' = ' + f['cpp_default'] if f['cpp_default'] is not None else ''};"
@@ -189,8 +251,15 @@ class {name} {{
cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}")
cpp_type_decls.append(f"class {name};")
thrift_type_defs[
name
] = f"""
struct {name} {{
{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())}
}}"""
def _handle_union(name, ty):
fields, cpp_fields = _handle_aggregate(ty)
fields, cpp_fields, thrift_fields = _handle_aggregate(ty)
yaml_ret[name] = {"kind": "union", "fields": fields}
def accessor(name, ty, idx):
@@ -253,6 +322,13 @@
"""
cpp_type_decls.append(f"class {name};")
thrift_type_defs[
name
] = f"""
union {name} {{
{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())}
}}"""
for name in dir(schema):
if name.startswith("_"):
continue
@@ -378,7 +454,13 @@ void from_json(const nlohmann::json& j, ForwardRef<T>& p) {{
}} // namespace _export
}} // namespace torch
"""
return yaml_ret, cpp_header
thrift_schema = f"""
namespace py3 torch._export.schema
namespace cpp2 torch._export.schema
{chr(10).join(thrift_enum_defs)}
{chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())}
"""
return yaml_ret, cpp_header, thrift_schema
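class_ordering is not shown in this hunk; presumably it maps each schema class name to its declaration position so the emitted Thrift definitions follow schema.py's order. A hedged sketch of the sort it drives:

    # Assumption: class_ordering = {"Device": 0, "SymExprHint": 1, ...}
    ordered_bodies = [
        body
        for _name, body in sorted(
            thrift_type_defs.items(), key=lambda kv: class_ordering[kv[0]]
        )
    ]
    thrift_body = "\n".join(ordered_bodies)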
def _diff_schema(dst, src):
@@ -461,6 +543,8 @@ class _Commit:
checksum_base: Optional[str]
cpp_header: str
cpp_header_path: str
thrift_schema: str
thrift_schema_path: str
def update_schema():
@@ -480,11 +564,13 @@ def update_schema():
checksum_base = None
dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
src, cpp_header = _staged_schema()
src, cpp_header, thrift_schema = _staged_schema()
additions, subtractions = _diff_schema(dst, src)
yaml_path = __package__.replace(".", "/") + "/schema.yaml"
thrift_schema_path = __package__.replace(".", "/") + "/schema.thrift"
torch_prefix = "torch/"
assert yaml_path.startswith(torch_prefix) # sanity check
assert thrift_schema_path.startswith(torch_prefix) # sanity check
return _Commit(
result=src,
@@ -496,6 +582,8 @@ def update_schema():
checksum_base=checksum_base,
cpp_header=cpp_header,
cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h",
thrift_schema=thrift_schema,
thrift_schema_path=thrift_schema_path,
)


@@ -12,14 +12,15 @@ import logging
import math
import operator
import re
import typing
import traceback
import typing
from collections import OrderedDict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import (
Annotated,
Any,
Callable,
cast,
@@ -51,7 +52,6 @@ from torch.utils._sympy.value_ranges import ValueRanges
from .schema import ( # type: ignore[attr-defined]
Argument,
BufferMutationSpec,
InputToConstantInputSpec,
ConstantValue,
CustomObjArgument,
Device,
@@ -64,6 +64,7 @@ from .schema import ( # type: ignore[attr-defined]
GraphSignature,
InputSpec,
InputToBufferSpec,
InputToConstantInputSpec,
InputToCustomObjSpec,
InputTokenSpec,
InputToParameterSpec,
@@ -2407,6 +2408,8 @@ def serialize(
def _dict_to_dataclass(cls, data):
assert not isinstance(cls, str), f"Unresolved class type: '{cls}'."
if typing.get_origin(cls) == Annotated:
return _dict_to_dataclass(cls.__origin__, data)
if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls):
if data is None:
return None
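The new Annotated branch in _dict_to_dataclass simply unwraps the metadata before recursing, so deserialization ignores the Thrift ids. A self-contained sketch of the pattern (a simplified stand-in, not the full function):

    import dataclasses
    import typing
    from typing import Annotated

    def from_dict(cls, data):
        if typing.get_origin(cls) == Annotated:
            return from_dict(cls.__origin__, data)  # strip the field id
        if dataclasses.is_dataclass(cls):
            return cls(**{
                f.name: from_dict(f.type, data[f.name])
                for f in dataclasses.fields(cls)
                if f.name in data
            })
        return data  # leaf value (str/int/float/bool/...)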