mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] Generate compatible thrift schema out of schema.py (#141611)
Summary: To make sure schema.py and schema.thrift are kept in sync, we use the int keys from thrift and use Python Annotated type to associate fields between thrift and schema.py. Later we will use this association to build a single source of truth between the schemas. Test Plan: CI Differential Revision: D66253157 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141611 Approved by: https://github.com/yiming0416
This commit is contained in:
committed by
PyTorch MergeBot
parent
7dd9b5fc43
commit
a8a570512b
@ -72,6 +72,10 @@ if __name__ == "__main__":
|
||||
|
||||
yaml_content = yaml_header + "\n" + yaml_payload
|
||||
|
||||
thrift_schema = "// " + first_line
|
||||
thrift_schema += "\n// " + checksum
|
||||
thrift_schema += "\n" + commit.thrift_schema
|
||||
|
||||
if args.dry_run:
|
||||
print(yaml_content)
|
||||
print("\nWill write the above schema to" + args.prefix + commit.yaml_path)
|
||||
@ -80,3 +84,5 @@ if __name__ == "__main__":
|
||||
f.write(yaml_content)
|
||||
with open(args.prefix + commit.cpp_header_path, "w") as f:
|
||||
f.write(cpp_header)
|
||||
with open(args.prefix + commit.thrift_schema_path, "w") as f:
|
||||
f.write(thrift_schema)
|
||||
|
@ -113,6 +113,8 @@ Example(s):
|
||||
checksum_base="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [4, 1])
|
||||
@ -147,6 +149,8 @@ Example(s):
|
||||
checksum_base="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [4, 1])
|
||||
@ -184,6 +188,8 @@ Example(s):
|
||||
checksum_base="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [3, 3])
|
||||
@ -244,6 +250,8 @@ Example(s):
|
||||
checksum_base="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [3, 3])
|
||||
@ -274,6 +282,8 @@ Example(s):
|
||||
checksum_base="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [3, 3])
|
||||
@ -311,6 +321,8 @@ Example(s):
|
||||
checksum_base="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [3, 3])
|
||||
@ -345,6 +357,8 @@ Example(s):
|
||||
checksum_base="",
|
||||
cpp_header="",
|
||||
cpp_header_path="",
|
||||
thrift_schema="",
|
||||
thrift_schema_path="",
|
||||
)
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [4, 1])
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Annotated, Dict, List, Optional, Tuple
|
||||
|
||||
from torch._export.serde.union import _Union
|
||||
|
||||
@ -53,15 +53,15 @@ class MemoryFormat(IntEnum):
|
||||
|
||||
@dataclass
|
||||
class Device:
|
||||
type: str
|
||||
index: Optional[int] = None
|
||||
type: Annotated[str, 10]
|
||||
index: Annotated[Optional[int], 20] = None
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class SymExprHint(_Union):
|
||||
as_int: int
|
||||
as_float: float
|
||||
as_bool: bool
|
||||
as_int: Annotated[int, 10]
|
||||
as_float: Annotated[float, 20]
|
||||
as_bool: Annotated[bool, 30]
|
||||
|
||||
|
||||
# This is for storing the symbolic expressions behind symints/symfloats/symbools
|
||||
@ -70,31 +70,31 @@ class SymExprHint(_Union):
|
||||
# if we also have the hint that s0 and s1 are both 2.
|
||||
@dataclass
|
||||
class SymExpr:
|
||||
expr_str: str
|
||||
hint: Optional[SymExprHint] = None
|
||||
expr_str: Annotated[str, 10]
|
||||
hint: Annotated[Optional[SymExprHint], 20] = None
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class SymInt(_Union):
|
||||
as_expr: SymExpr
|
||||
as_int: int
|
||||
as_expr: Annotated[SymExpr, 10]
|
||||
as_int: Annotated[int, 20]
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class SymBool(_Union):
|
||||
as_expr: SymExpr
|
||||
as_bool: bool
|
||||
as_expr: Annotated[SymExpr, 10]
|
||||
as_bool: Annotated[bool, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorMeta:
|
||||
dtype: ScalarType
|
||||
sizes: List[SymInt]
|
||||
requires_grad: bool
|
||||
device: Device
|
||||
strides: List[SymInt]
|
||||
storage_offset: SymInt
|
||||
layout: Layout
|
||||
dtype: Annotated[ScalarType, 10]
|
||||
sizes: Annotated[List[SymInt], 20]
|
||||
requires_grad: Annotated[bool, 30]
|
||||
device: Annotated[Device, 40]
|
||||
strides: Annotated[List[SymInt], 50]
|
||||
storage_offset: Annotated[SymInt, 60]
|
||||
layout: Annotated[Layout, 70]
|
||||
|
||||
|
||||
# In most cases we will use the "as_name" field to store arguments which are
|
||||
@ -105,8 +105,8 @@ class TensorMeta:
|
||||
# to the "as_int" field.
|
||||
@dataclass(repr=False)
|
||||
class SymIntArgument(_Union):
|
||||
as_name: str
|
||||
as_int: int
|
||||
as_name: Annotated[str, 10]
|
||||
as_int: Annotated[int, 20]
|
||||
|
||||
|
||||
# In most cases we will use the "as_name" field to store arguments which are
|
||||
@ -117,18 +117,18 @@ class SymIntArgument(_Union):
|
||||
# to the "as_bool" field.
|
||||
@dataclass(repr=False)
|
||||
class SymBoolArgument(_Union):
|
||||
as_name: str
|
||||
as_bool: bool
|
||||
as_name: Annotated[str, 10]
|
||||
as_bool: Annotated[bool, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorArgument:
|
||||
name: str
|
||||
name: Annotated[str, 10]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenArgument:
|
||||
name: str
|
||||
name: Annotated[str, 10]
|
||||
|
||||
|
||||
# This is used for storing the contents of a list which contains optional tensors
|
||||
@ -137,252 +137,252 @@ class TokenArgument:
|
||||
# "as_tensor" field, and None values serialized to the "as_none" field.
|
||||
@dataclass(repr=False)
|
||||
class OptionalTensorArgument(_Union):
|
||||
as_tensor: TensorArgument
|
||||
as_none: Tuple[()]
|
||||
as_tensor: Annotated[TensorArgument, 20]
|
||||
as_none: Annotated[Tuple[()], 10]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphArgument:
|
||||
name: str
|
||||
graph: 'Graph'
|
||||
name: Annotated[str, 10]
|
||||
graph: Annotated['Graph', 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomObjArgument:
|
||||
name: str
|
||||
class_fqn: str
|
||||
name: Annotated[str, 10]
|
||||
class_fqn: Annotated[str, 20]
|
||||
|
||||
|
||||
# This is actually a union type
|
||||
@dataclass(repr=False)
|
||||
class Argument(_Union):
|
||||
as_none: Tuple[()]
|
||||
as_tensor: TensorArgument
|
||||
as_tensors: List[TensorArgument]
|
||||
as_int: int
|
||||
as_ints: List[int]
|
||||
as_float: float
|
||||
as_floats: List[float]
|
||||
as_string: str
|
||||
as_strings: List[str]
|
||||
as_sym_int: SymIntArgument
|
||||
as_sym_ints: List[SymIntArgument]
|
||||
as_scalar_type: ScalarType
|
||||
as_memory_format: MemoryFormat
|
||||
as_layout: Layout
|
||||
as_device: Device
|
||||
as_bool: bool
|
||||
as_bools: List[bool]
|
||||
as_sym_bool: SymBoolArgument
|
||||
as_sym_bools: List[SymBoolArgument]
|
||||
as_graph: GraphArgument
|
||||
as_optional_tensors: List[OptionalTensorArgument]
|
||||
as_custom_obj: CustomObjArgument
|
||||
as_operator: str
|
||||
as_none: Annotated[Tuple[()], 10]
|
||||
as_tensor: Annotated[TensorArgument, 20]
|
||||
as_tensors: Annotated[List[TensorArgument], 30]
|
||||
as_int: Annotated[int, 50]
|
||||
as_ints: Annotated[List[int], 70]
|
||||
as_float: Annotated[float, 80]
|
||||
as_floats: Annotated[List[float], 90]
|
||||
as_string: Annotated[str, 100]
|
||||
as_strings: Annotated[List[str], 101]
|
||||
as_sym_int: Annotated[SymIntArgument, 110]
|
||||
as_sym_ints: Annotated[List[SymIntArgument], 120]
|
||||
as_scalar_type: Annotated[ScalarType, 130]
|
||||
as_memory_format: Annotated[MemoryFormat, 140]
|
||||
as_layout: Annotated[Layout, 150]
|
||||
as_device: Annotated[Device, 160]
|
||||
as_bool: Annotated[bool, 170]
|
||||
as_bools: Annotated[List[bool], 180]
|
||||
as_sym_bool: Annotated[SymBoolArgument, 182]
|
||||
as_sym_bools: Annotated[List[SymBoolArgument], 184]
|
||||
as_graph: Annotated[GraphArgument, 200]
|
||||
as_optional_tensors: Annotated[List[OptionalTensorArgument], 190]
|
||||
as_custom_obj: Annotated[CustomObjArgument, 210]
|
||||
as_operator: Annotated[str, 220]
|
||||
|
||||
|
||||
@dataclass
|
||||
class NamedArgument:
|
||||
# Argument name from the operator schema
|
||||
name: str
|
||||
arg: Argument
|
||||
name: Annotated[str, 10]
|
||||
arg: Annotated[Argument, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Node:
|
||||
target: str
|
||||
inputs: List[NamedArgument]
|
||||
outputs: List[Argument]
|
||||
metadata: Dict[str, str]
|
||||
target: Annotated[str, 10]
|
||||
inputs: Annotated[List[NamedArgument], 20]
|
||||
outputs: Annotated[List[Argument], 30]
|
||||
metadata: Annotated[Dict[str, str], 40]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Graph:
|
||||
inputs: List[Argument]
|
||||
outputs: List[Argument]
|
||||
nodes: List[Node]
|
||||
tensor_values: Dict[str, TensorMeta]
|
||||
sym_int_values: Dict[str, SymInt]
|
||||
sym_bool_values: Dict[str, SymBool]
|
||||
inputs: Annotated[List[Argument], 10]
|
||||
outputs: Annotated[List[Argument], 20]
|
||||
nodes: Annotated[List[Node], 30]
|
||||
tensor_values: Annotated[Dict[str, TensorMeta], 40]
|
||||
sym_int_values: Annotated[Dict[str, SymInt], 50]
|
||||
sym_bool_values: Annotated[Dict[str, SymBool], 60]
|
||||
# This is for deserializing the submodule graphs from higher order ops
|
||||
# (ex. cond, map) where single tensor returns will just return a single
|
||||
# tensor, rather than following export schema and returning a singleton
|
||||
# list.
|
||||
is_single_tensor_return: bool = False
|
||||
custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
|
||||
is_single_tensor_return: Annotated[bool, 70] = False
|
||||
custom_obj_values: Annotated[Dict[str, CustomObjArgument], 80] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInputSpec:
|
||||
# Actually, only tensors and SymInts are allowed here
|
||||
arg: Argument
|
||||
arg: Annotated[Argument, 10]
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class ConstantValue(_Union):
|
||||
as_none: Tuple[()]
|
||||
as_int: int
|
||||
as_float: float
|
||||
as_string: str
|
||||
as_bool: bool
|
||||
as_none: Annotated[Tuple[()], 10]
|
||||
as_int: Annotated[int, 20]
|
||||
as_float: Annotated[float, 30]
|
||||
as_string: Annotated[str, 40]
|
||||
as_bool: Annotated[bool, 50]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputToConstantInputSpec:
|
||||
name: str
|
||||
value: ConstantValue
|
||||
name: Annotated[str, 10]
|
||||
value: Annotated[ConstantValue, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputToParameterSpec:
|
||||
arg: TensorArgument
|
||||
parameter_name: str
|
||||
arg: Annotated[TensorArgument, 10]
|
||||
parameter_name: Annotated[str, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputToBufferSpec:
|
||||
arg: TensorArgument
|
||||
buffer_name: str
|
||||
persistent: bool
|
||||
arg: Annotated[TensorArgument, 10]
|
||||
buffer_name: Annotated[str, 20]
|
||||
persistent: Annotated[bool, 30]
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputToTensorConstantSpec:
|
||||
arg: TensorArgument
|
||||
tensor_constant_name: str
|
||||
arg: Annotated[TensorArgument, 10]
|
||||
tensor_constant_name: Annotated[str, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputToCustomObjSpec:
|
||||
arg: CustomObjArgument
|
||||
custom_obj_name: str
|
||||
arg: Annotated[CustomObjArgument, 10]
|
||||
custom_obj_name: Annotated[str, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputTokenSpec:
|
||||
arg: TokenArgument
|
||||
arg: Annotated[TokenArgument, 10]
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class InputSpec(_Union):
|
||||
user_input: UserInputSpec
|
||||
parameter: InputToParameterSpec
|
||||
buffer: InputToBufferSpec
|
||||
tensor_constant: InputToTensorConstantSpec
|
||||
custom_obj: InputToCustomObjSpec
|
||||
token: InputTokenSpec
|
||||
constant_input: InputToConstantInputSpec
|
||||
user_input: Annotated[UserInputSpec, 10]
|
||||
parameter: Annotated[InputToParameterSpec, 20]
|
||||
buffer: Annotated[InputToBufferSpec, 30]
|
||||
tensor_constant: Annotated[InputToTensorConstantSpec, 40]
|
||||
custom_obj: Annotated[InputToCustomObjSpec, 50]
|
||||
token: Annotated[InputTokenSpec, 70]
|
||||
constant_input: Annotated[InputToConstantInputSpec, 60]
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserOutputSpec:
|
||||
arg: Argument
|
||||
arg: Annotated[Argument, 10]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LossOutputSpec:
|
||||
arg: TensorArgument
|
||||
arg: Annotated[TensorArgument, 10]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BufferMutationSpec:
|
||||
arg: TensorArgument
|
||||
buffer_name: str
|
||||
arg: Annotated[TensorArgument, 10]
|
||||
buffer_name: Annotated[str, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GradientToParameterSpec:
|
||||
arg: TensorArgument
|
||||
parameter_name: str
|
||||
arg: Annotated[TensorArgument, 10]
|
||||
parameter_name: Annotated[str, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GradientToUserInputSpec:
|
||||
arg: TensorArgument
|
||||
user_input_name: str
|
||||
arg: Annotated[TensorArgument, 10]
|
||||
user_input_name: Annotated[str, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInputMutationSpec:
|
||||
arg: TensorArgument
|
||||
user_input_name: str
|
||||
arg: Annotated[TensorArgument, 10]
|
||||
user_input_name: Annotated[str, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputTokenSpec:
|
||||
arg: TokenArgument
|
||||
arg: Annotated[TokenArgument, 10]
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class OutputSpec(_Union):
|
||||
user_output: UserOutputSpec
|
||||
loss_output: LossOutputSpec
|
||||
buffer_mutation: BufferMutationSpec
|
||||
gradient_to_parameter: GradientToParameterSpec
|
||||
gradient_to_user_input: GradientToUserInputSpec
|
||||
user_input_mutation: UserInputMutationSpec
|
||||
token: OutputTokenSpec
|
||||
user_output: Annotated[UserOutputSpec, 10]
|
||||
loss_output: Annotated[LossOutputSpec, 20]
|
||||
buffer_mutation: Annotated[BufferMutationSpec, 30]
|
||||
gradient_to_parameter: Annotated[GradientToParameterSpec, 40]
|
||||
gradient_to_user_input: Annotated[GradientToUserInputSpec, 50]
|
||||
user_input_mutation: Annotated[UserInputMutationSpec, 60]
|
||||
token: Annotated[OutputTokenSpec, 70]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphSignature:
|
||||
input_specs: List[InputSpec]
|
||||
output_specs: List[OutputSpec]
|
||||
input_specs: Annotated[List[InputSpec], 10]
|
||||
output_specs: Annotated[List[OutputSpec], 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RangeConstraint:
|
||||
min_val: Optional[int]
|
||||
max_val: Optional[int]
|
||||
min_val: Annotated[Optional[int], 10]
|
||||
max_val: Annotated[Optional[int], 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleCallSignature:
|
||||
inputs: List[Argument]
|
||||
outputs: List[Argument]
|
||||
inputs: Annotated[List[Argument], 10]
|
||||
outputs: Annotated[List[Argument], 20]
|
||||
|
||||
# These are serialized by calling pytree.treespec_dumps
|
||||
# And deserialized by calling pytree.treespec_loads
|
||||
in_spec: str
|
||||
out_spec: str
|
||||
in_spec: Annotated[str, 30]
|
||||
out_spec: Annotated[str, 40]
|
||||
|
||||
# This field is used to prettify the graph placeholders
|
||||
# after we ser/der and retrace
|
||||
forward_arg_names: Optional[List[str]] = None
|
||||
forward_arg_names: Annotated[Optional[List[str]], 50] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleCallEntry:
|
||||
fqn: str
|
||||
signature: Optional[ModuleCallSignature] = None
|
||||
fqn: Annotated[str, 10]
|
||||
signature: Annotated[Optional[ModuleCallSignature], 30] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphModule:
|
||||
graph: Graph
|
||||
signature: GraphSignature
|
||||
graph: Annotated[Graph, 10]
|
||||
signature: Annotated[GraphSignature, 50]
|
||||
# This is used for unflattening, by tracking the calling structure of all of
|
||||
# the modules in order to unflatten the modules back to the eager calling
|
||||
# conventions.
|
||||
module_call_graph: List[ModuleCallEntry]
|
||||
metadata: Dict[str, str] = field(default_factory=dict)
|
||||
module_call_graph: Annotated[List[ModuleCallEntry], 60]
|
||||
metadata: Annotated[Dict[str, str], 40] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Invariant: Every time a change is made to the schema, one of the versions
|
||||
# should be updated.
|
||||
@dataclass
|
||||
class SchemaVersion:
|
||||
major: int # Major version number is bumped every time a breaking change is made.
|
||||
minor: int # Minor version number is bumped when a compatible change is made.
|
||||
major: Annotated[int, 10] # Major version number is bumped every time a breaking change is made.
|
||||
minor: Annotated[int, 20] # Minor version number is bumped when a compatible change is made.
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExportedProgram:
|
||||
graph_module: GraphModule
|
||||
graph_module: Annotated[GraphModule, 10]
|
||||
# Key is the opset namespace (ex. aten), and value is the version number
|
||||
opset_version: Dict[str, int]
|
||||
range_constraints: Dict[str, RangeConstraint]
|
||||
schema_version: SchemaVersion
|
||||
verifiers: List[str] = field(default_factory=list)
|
||||
torch_version: str = "<=2.4"
|
||||
opset_version: Annotated[Dict[str, int], 20]
|
||||
range_constraints: Annotated[Dict[str, RangeConstraint], 30]
|
||||
schema_version: Annotated[SchemaVersion, 60]
|
||||
verifiers: Annotated[List[str], 70] = field(default_factory=list)
|
||||
torch_version: Annotated[str, 80] = "<=2.4"
|
||||
|
301
torch/_export/serde/schema.thrift
Normal file
301
torch/_export/serde/schema.thrift
Normal file
@ -0,0 +1,301 @@
|
||||
// @generated by update_schema.py
|
||||
// checksum<<19d86105f895a10d5eedbc6e13d4d96cf5d9182c0367d6825ef2438e124cc536>>
|
||||
|
||||
namespace py3 torch._export.schema
|
||||
namespace cpp2 torch._export.schema
|
||||
|
||||
enum Layout {
|
||||
Unknown = 0,
|
||||
SparseCoo = 1,
|
||||
SparseCsr = 2,
|
||||
SparseCsc = 3,
|
||||
SparseBsr = 4,
|
||||
SparseBsc = 5,
|
||||
_mkldnn = 6,
|
||||
Strided = 7,
|
||||
}
|
||||
|
||||
|
||||
enum MemoryFormat {
|
||||
Unknown = 0,
|
||||
ContiguousFormat = 1,
|
||||
ChannelsLast = 2,
|
||||
ChannelsLast3d = 3,
|
||||
PreserveFormat = 4,
|
||||
}
|
||||
|
||||
|
||||
enum ScalarType {
|
||||
UNKNOWN = 0,
|
||||
BYTE = 1,
|
||||
CHAR = 2,
|
||||
SHORT = 3,
|
||||
INT = 4,
|
||||
LONG = 5,
|
||||
HALF = 6,
|
||||
FLOAT = 7,
|
||||
DOUBLE = 8,
|
||||
COMPLEXHALF = 9,
|
||||
COMPLEXFLOAT = 10,
|
||||
COMPLEXDOUBLE = 11,
|
||||
BOOL = 12,
|
||||
BFLOAT16 = 13,
|
||||
UINT16 = 28,
|
||||
}
|
||||
|
||||
|
||||
struct Device {
|
||||
10: string type;
|
||||
20: optional i64 index;
|
||||
}
|
||||
|
||||
union SymExprHint {
|
||||
10: i64 as_int;
|
||||
20: double as_float;
|
||||
30: bool as_bool;
|
||||
}
|
||||
|
||||
struct SymExpr {
|
||||
10: string expr_str;
|
||||
20: optional SymExprHint hint;
|
||||
}
|
||||
|
||||
union SymInt {
|
||||
10: SymExpr as_expr;
|
||||
20: i64 as_int;
|
||||
}
|
||||
|
||||
union SymBool {
|
||||
10: SymExpr as_expr;
|
||||
20: bool as_bool;
|
||||
}
|
||||
|
||||
struct TensorMeta {
|
||||
10: ScalarType dtype;
|
||||
20: list<SymInt> sizes;
|
||||
30: bool requires_grad;
|
||||
40: Device device;
|
||||
50: list<SymInt> strides;
|
||||
60: SymInt storage_offset;
|
||||
70: Layout layout;
|
||||
}
|
||||
|
||||
union SymIntArgument {
|
||||
10: string as_name;
|
||||
20: i64 as_int;
|
||||
}
|
||||
|
||||
union SymBoolArgument {
|
||||
10: string as_name;
|
||||
20: bool as_bool;
|
||||
}
|
||||
|
||||
struct TensorArgument {
|
||||
10: string name;
|
||||
}
|
||||
|
||||
struct TokenArgument {
|
||||
10: string name;
|
||||
}
|
||||
|
||||
union OptionalTensorArgument {
|
||||
20: TensorArgument as_tensor;
|
||||
10: bool as_none;
|
||||
}
|
||||
|
||||
struct GraphArgument {
|
||||
10: string name;
|
||||
20: Graph graph;
|
||||
}
|
||||
|
||||
struct CustomObjArgument {
|
||||
10: string name;
|
||||
20: string class_fqn;
|
||||
}
|
||||
|
||||
union Argument {
|
||||
10: bool as_none;
|
||||
20: TensorArgument as_tensor;
|
||||
30: list<TensorArgument> as_tensors;
|
||||
50: i64 as_int;
|
||||
70: list<i64> as_ints;
|
||||
80: double as_float;
|
||||
90: list<double> as_floats;
|
||||
100: string as_string;
|
||||
101: list<string> as_strings;
|
||||
110: SymIntArgument as_sym_int;
|
||||
120: list<SymIntArgument> as_sym_ints;
|
||||
130: ScalarType as_scalar_type;
|
||||
140: MemoryFormat as_memory_format;
|
||||
150: Layout as_layout;
|
||||
160: Device as_device;
|
||||
170: bool as_bool;
|
||||
180: list<bool> as_bools;
|
||||
182: SymBoolArgument as_sym_bool;
|
||||
184: list<SymBoolArgument> as_sym_bools;
|
||||
200: GraphArgument as_graph;
|
||||
190: list<OptionalTensorArgument> as_optional_tensors;
|
||||
210: CustomObjArgument as_custom_obj;
|
||||
220: string as_operator;
|
||||
}
|
||||
|
||||
struct NamedArgument {
|
||||
10: string name;
|
||||
20: Argument arg;
|
||||
}
|
||||
|
||||
struct Node {
|
||||
10: string target;
|
||||
20: list<NamedArgument> inputs;
|
||||
30: list<Argument> outputs;
|
||||
40: map<string, string> metadata;
|
||||
}
|
||||
|
||||
struct Graph {
|
||||
10: list<Argument> inputs;
|
||||
20: list<Argument> outputs;
|
||||
30: list<Node> nodes;
|
||||
40: map<string, TensorMeta> tensor_values;
|
||||
50: map<string, SymInt> sym_int_values;
|
||||
60: map<string, SymBool> sym_bool_values;
|
||||
70: bool is_single_tensor_return;
|
||||
80: map<string, CustomObjArgument> custom_obj_values;
|
||||
}
|
||||
|
||||
struct UserInputSpec {
|
||||
10: Argument arg;
|
||||
}
|
||||
|
||||
union ConstantValue {
|
||||
10: bool as_none;
|
||||
20: i64 as_int;
|
||||
30: double as_float;
|
||||
40: string as_string;
|
||||
50: bool as_bool;
|
||||
}
|
||||
|
||||
struct InputToConstantInputSpec {
|
||||
10: string name;
|
||||
20: ConstantValue value;
|
||||
}
|
||||
|
||||
struct InputToParameterSpec {
|
||||
10: TensorArgument arg;
|
||||
20: string parameter_name;
|
||||
}
|
||||
|
||||
struct InputToBufferSpec {
|
||||
10: TensorArgument arg;
|
||||
20: string buffer_name;
|
||||
30: bool persistent;
|
||||
}
|
||||
|
||||
struct InputToTensorConstantSpec {
|
||||
10: TensorArgument arg;
|
||||
20: string tensor_constant_name;
|
||||
}
|
||||
|
||||
struct InputToCustomObjSpec {
|
||||
10: CustomObjArgument arg;
|
||||
20: string custom_obj_name;
|
||||
}
|
||||
|
||||
struct InputTokenSpec {
|
||||
10: TokenArgument arg;
|
||||
}
|
||||
|
||||
union InputSpec {
|
||||
10: UserInputSpec user_input;
|
||||
20: InputToParameterSpec parameter;
|
||||
30: InputToBufferSpec buffer;
|
||||
40: InputToTensorConstantSpec tensor_constant;
|
||||
50: InputToCustomObjSpec custom_obj;
|
||||
70: InputTokenSpec token;
|
||||
60: InputToConstantInputSpec constant_input;
|
||||
}
|
||||
|
||||
struct UserOutputSpec {
|
||||
10: Argument arg;
|
||||
}
|
||||
|
||||
struct LossOutputSpec {
|
||||
10: TensorArgument arg;
|
||||
}
|
||||
|
||||
struct BufferMutationSpec {
|
||||
10: TensorArgument arg;
|
||||
20: string buffer_name;
|
||||
}
|
||||
|
||||
struct GradientToParameterSpec {
|
||||
10: TensorArgument arg;
|
||||
20: string parameter_name;
|
||||
}
|
||||
|
||||
struct GradientToUserInputSpec {
|
||||
10: TensorArgument arg;
|
||||
20: string user_input_name;
|
||||
}
|
||||
|
||||
struct UserInputMutationSpec {
|
||||
10: TensorArgument arg;
|
||||
20: string user_input_name;
|
||||
}
|
||||
|
||||
struct OutputTokenSpec {
|
||||
10: TokenArgument arg;
|
||||
}
|
||||
|
||||
union OutputSpec {
|
||||
10: UserOutputSpec user_output;
|
||||
20: LossOutputSpec loss_output;
|
||||
30: BufferMutationSpec buffer_mutation;
|
||||
40: GradientToParameterSpec gradient_to_parameter;
|
||||
50: GradientToUserInputSpec gradient_to_user_input;
|
||||
60: UserInputMutationSpec user_input_mutation;
|
||||
70: OutputTokenSpec token;
|
||||
}
|
||||
|
||||
struct GraphSignature {
|
||||
10: list<InputSpec> input_specs;
|
||||
20: list<OutputSpec> output_specs;
|
||||
}
|
||||
|
||||
struct RangeConstraint {
|
||||
10: optional i64 min_val;
|
||||
20: optional i64 max_val;
|
||||
}
|
||||
|
||||
struct ModuleCallSignature {
|
||||
10: list<Argument> inputs;
|
||||
20: list<Argument> outputs;
|
||||
30: string in_spec;
|
||||
40: string out_spec;
|
||||
50: optional list<string> forward_arg_names;
|
||||
}
|
||||
|
||||
struct ModuleCallEntry {
|
||||
10: string fqn;
|
||||
30: optional ModuleCallSignature signature;
|
||||
}
|
||||
|
||||
struct GraphModule {
|
||||
10: Graph graph;
|
||||
50: GraphSignature signature;
|
||||
60: list<ModuleCallEntry> module_call_graph;
|
||||
40: map<string, string> metadata;
|
||||
}
|
||||
|
||||
struct SchemaVersion {
|
||||
10: i64 major;
|
||||
20: i64 minor;
|
||||
}
|
||||
|
||||
struct ExportedProgram {
|
||||
10: GraphModule graph_module;
|
||||
20: map<string, i64> opset_version;
|
||||
30: map<string, RangeConstraint> range_constraints;
|
||||
60: SchemaVersion schema_version;
|
||||
70: list<string> verifiers;
|
||||
80: string torch_version;
|
||||
}
|
@ -5,7 +5,7 @@ import inspect
|
||||
import re
|
||||
import typing
|
||||
from enum import IntEnum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Annotated, Any, Dict, ForwardRef, List, Optional, Tuple, Union
|
||||
|
||||
from torch._export.serde import schema
|
||||
from torch._export.serde.union import _Union
|
||||
@ -27,50 +27,91 @@ def _staged_schema():
|
||||
cpp_class_defs: Dict[str, str] = {}
|
||||
cpp_type_decls: List[str] = []
|
||||
cpp_json_defs: List[str] = []
|
||||
thrift_enum_defs: List[str] = []
|
||||
thrift_type_defs: Dict[str, str] = {}
|
||||
|
||||
def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
def dump_type(t) -> Tuple[str, str]:
|
||||
TYPE_MAP = {
|
||||
def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
def dump_type(t) -> Tuple[str, str, str]:
|
||||
CPP_TYPE_MAP = {
|
||||
str: "std::string",
|
||||
int: "int64_t",
|
||||
float: "double",
|
||||
bool: "bool",
|
||||
}
|
||||
THRIFT_TYPE_MAP = {
|
||||
str: "string",
|
||||
int: "i64",
|
||||
float: "double",
|
||||
bool: "bool",
|
||||
}
|
||||
if isinstance(t, type):
|
||||
if t.__name__ in cpp_enum_defs:
|
||||
return t.__name__, "int64_t"
|
||||
return t.__name__, "int64_t", t.__name__
|
||||
else:
|
||||
return t.__name__, TYPE_MAP.get(t, t.__name__)
|
||||
return (
|
||||
t.__name__,
|
||||
CPP_TYPE_MAP.get(t, t.__name__),
|
||||
THRIFT_TYPE_MAP.get(t, t.__name__),
|
||||
)
|
||||
elif isinstance(t, str):
|
||||
assert t in defs
|
||||
assert t not in cpp_enum_defs
|
||||
assert "[" not in t
|
||||
return t, f"ForwardRef<{t}>"
|
||||
return t, f"ForwardRef<{t}>", t
|
||||
elif isinstance(t, ForwardRef):
|
||||
return (
|
||||
t.__forward_arg__,
|
||||
f"ForwardRef<{t.__forward_arg__}>",
|
||||
t.__forward_arg__,
|
||||
)
|
||||
elif o := typing.get_origin(t):
|
||||
# Lemme know if there's a better way to do this.
|
||||
if o == list:
|
||||
yaml_head, cpp_head = "List", "std::vector"
|
||||
yaml_head, cpp_head, thrift_head, thrift_tail = (
|
||||
"List",
|
||||
"std::vector",
|
||||
"list<",
|
||||
">",
|
||||
)
|
||||
elif o == dict:
|
||||
yaml_head, cpp_head = "Dict", "std::unordered_map"
|
||||
yaml_head, cpp_head, thrift_head, thrift_tail = (
|
||||
"Dict",
|
||||
"std::unordered_map",
|
||||
"map<",
|
||||
">",
|
||||
)
|
||||
elif o == tuple:
|
||||
if typing.get_args(t) == ():
|
||||
return "Tuple[()]", "std::tuple<>"
|
||||
yaml_head, cpp_head = "Tuple", "std::tuple"
|
||||
return "Tuple[()]", "std::tuple<>", "bool"
|
||||
yaml_head, cpp_head, thrift_head, thrift_tail = (
|
||||
"Tuple",
|
||||
"std::tuple",
|
||||
"bool",
|
||||
"",
|
||||
)
|
||||
elif o == Union:
|
||||
args = typing.get_args(t)
|
||||
assert len(args) == 2 and args[1] == type(None)
|
||||
yaml_type, cpp_type = dump_type(args[0])
|
||||
return f"Optional[{yaml_type}]", f"std::optional<{cpp_type}>"
|
||||
yaml_type, cpp_type, thrift_type = dump_type(args[0])
|
||||
return (
|
||||
f"Optional[{yaml_type}]",
|
||||
f"std::optional<{cpp_type}>",
|
||||
f"optional {thrift_type}",
|
||||
)
|
||||
elif o == Annotated:
|
||||
return dump_type(t.__origin__)
|
||||
else:
|
||||
raise AssertionError(f"Type {t} is not supported in export schema.")
|
||||
yaml_arg_types, cpp_arg_types = zip(
|
||||
yaml_arg_types, cpp_arg_types, thrift_arg_types = zip(
|
||||
*[dump_type(x) for x in typing.get_args(t)]
|
||||
)
|
||||
return (f"{yaml_head}[{', '.join(yaml_arg_types)}]"), (
|
||||
f"{cpp_head}<{', '.join(cpp_arg_types)}>"
|
||||
return (
|
||||
(f"{yaml_head}[{', '.join(yaml_arg_types)}]"),
|
||||
(f"{cpp_head}<{', '.join(cpp_arg_types)}>"),
|
||||
f"{thrift_head}{', '.join(thrift_arg_types)}{thrift_tail}",
|
||||
)
|
||||
elif t == ():
|
||||
return "()", ""
|
||||
return "()", "", ""
|
||||
else:
|
||||
raise AssertionError(f"Type {t} is not supported in export schema.")
|
||||
|
||||
@ -94,11 +135,17 @@ def _staged_schema():
|
||||
f"Default value {v} is not supported yet in export schema."
|
||||
)
|
||||
|
||||
def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str]]:
|
||||
t, cpp = dump_type(f.type)
|
||||
def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str], str, int]:
|
||||
t, cpp_type, thrift_type = dump_type(f.type)
|
||||
ret = {"type": t}
|
||||
cpp_type = cpp
|
||||
cpp_default: Optional[str] = None
|
||||
assert (
|
||||
typing.get_origin(f.type) == Annotated
|
||||
), f"Field {f.name} must be annotated with an integer id."
|
||||
thrift_id = f.type.__metadata__[0]
|
||||
assert (
|
||||
type(thrift_id) is int
|
||||
), f"Field {f.name} must be annotated with an integer id."
|
||||
|
||||
value = dataclasses.MISSING
|
||||
if f.default is not dataclasses.MISSING:
|
||||
@ -116,15 +163,23 @@ def _staged_schema():
|
||||
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
|
||||
)
|
||||
|
||||
return ret, cpp_type, cpp_default
|
||||
return ret, cpp_type, cpp_default, thrift_type, thrift_id
|
||||
|
||||
yaml_ret = {}
|
||||
cpp_ret = {}
|
||||
thrift_ret = {}
|
||||
thrift_ids = set()
|
||||
for f in dataclasses.fields(ty):
|
||||
yaml_res, cpp_type, cpp_default = dump_field(f)
|
||||
yaml_res, cpp_type, cpp_default, thrift_type, thrift_id = dump_field(f)
|
||||
yaml_ret[f.name] = yaml_res
|
||||
cpp_ret[f.name] = {"cpp_type": cpp_type, "cpp_default": cpp_default}
|
||||
return yaml_ret, cpp_ret
|
||||
thrift_ret[f.name] = {"thrift_type": thrift_type, "thrift_id": thrift_id}
|
||||
if thrift_id in thrift_ids:
|
||||
raise AssertionError(
|
||||
f"Duplicate thrift id {thrift_id} for field {f.name} in {ty.__name__}."
|
||||
)
|
||||
thrift_ids.add(thrift_id)
|
||||
return yaml_ret, cpp_ret, thrift_ret
|
||||
|
||||
def _handle_int_enum(name, ty):
|
||||
yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}
|
||||
@ -135,9 +190,16 @@ enum class {name} {{
|
||||
{chr(10).join([f" {x.name} = {x.value}," for x in ty])}
|
||||
}};
|
||||
"""
|
||||
thrift_enum_defs.append(
|
||||
f"""
|
||||
enum {name} {{
|
||||
{chr(10).join([f" {x.name} = {x.value}," for x in ty])}
|
||||
}}
|
||||
"""
|
||||
)
|
||||
|
||||
def _handle_struct(name, ty):
|
||||
fields, cpp_fields = _handle_aggregate(ty)
|
||||
fields, cpp_fields, thrift_fields = _handle_aggregate(ty)
|
||||
yaml_ret[name] = {"kind": "struct", "fields": fields}
|
||||
field_decls = "\n".join(
|
||||
f" {f['cpp_type']} {name}{' = ' + f['cpp_default'] if f['cpp_default'] is not None else ''};"
|
||||
@ -189,8 +251,15 @@ class {name} {{
|
||||
cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}")
|
||||
cpp_type_decls.append(f"class {name};")
|
||||
|
||||
thrift_type_defs[
|
||||
name
|
||||
] = f"""
|
||||
struct {name} {{
|
||||
{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())}
|
||||
}}"""
|
||||
|
||||
def _handle_union(name, ty):
|
||||
fields, cpp_fields = _handle_aggregate(ty)
|
||||
fields, cpp_fields, thrift_fields = _handle_aggregate(ty)
|
||||
yaml_ret[name] = {"kind": "union", "fields": fields}
|
||||
|
||||
def accessor(name, ty, idx):
|
||||
@ -253,6 +322,13 @@ class {name} {{
|
||||
"""
|
||||
cpp_type_decls.append(f"class {name};")
|
||||
|
||||
thrift_type_defs[
|
||||
name
|
||||
] = f"""
|
||||
union {name} {{
|
||||
{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())}
|
||||
}}"""
|
||||
|
||||
for name in dir(schema):
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
@ -378,7 +454,13 @@ void from_json(const nlohmann::json& j, ForwardRef<T>& p) {{
|
||||
}} // namespace _export
|
||||
}} // namespace torch
|
||||
"""
|
||||
return yaml_ret, cpp_header
|
||||
thrift_schema = f"""
|
||||
namespace py3 torch._export.schema
|
||||
namespace cpp2 torch._export.schema
|
||||
{chr(10).join(thrift_enum_defs)}
|
||||
{chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())}
|
||||
"""
|
||||
return yaml_ret, cpp_header, thrift_schema
|
||||
|
||||
|
||||
def _diff_schema(dst, src):
|
||||
@ -461,6 +543,8 @@ class _Commit:
|
||||
checksum_base: Optional[str]
|
||||
cpp_header: str
|
||||
cpp_header_path: str
|
||||
thrift_schema: str
|
||||
thrift_schema_path: str
|
||||
|
||||
|
||||
def update_schema():
|
||||
@ -480,11 +564,13 @@ def update_schema():
|
||||
checksum_base = None
|
||||
dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
|
||||
|
||||
src, cpp_header = _staged_schema()
|
||||
src, cpp_header, thrift_schema = _staged_schema()
|
||||
additions, subtractions = _diff_schema(dst, src)
|
||||
yaml_path = __package__.replace(".", "/") + "/schema.yaml"
|
||||
thrift_schema_path = __package__.replace(".", "/") + "/schema.thrift"
|
||||
torch_prefix = "torch/"
|
||||
assert yaml_path.startswith(torch_prefix) # sanity check
|
||||
assert thrift_schema_path.startswith(torch_prefix) # sanity check
|
||||
|
||||
return _Commit(
|
||||
result=src,
|
||||
@ -496,6 +582,8 @@ def update_schema():
|
||||
checksum_base=checksum_base,
|
||||
cpp_header=cpp_header,
|
||||
cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h",
|
||||
thrift_schema=thrift_schema,
|
||||
thrift_schema_path=thrift_schema_path,
|
||||
)
|
||||
|
||||
|
||||
|
@ -12,14 +12,15 @@ import logging
|
||||
import math
|
||||
import operator
|
||||
import re
|
||||
import typing
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
@ -51,7 +52,6 @@ from torch.utils._sympy.value_ranges import ValueRanges
|
||||
from .schema import ( # type: ignore[attr-defined]
|
||||
Argument,
|
||||
BufferMutationSpec,
|
||||
InputToConstantInputSpec,
|
||||
ConstantValue,
|
||||
CustomObjArgument,
|
||||
Device,
|
||||
@ -64,6 +64,7 @@ from .schema import ( # type: ignore[attr-defined]
|
||||
GraphSignature,
|
||||
InputSpec,
|
||||
InputToBufferSpec,
|
||||
InputToConstantInputSpec,
|
||||
InputToCustomObjSpec,
|
||||
InputTokenSpec,
|
||||
InputToParameterSpec,
|
||||
@ -2407,6 +2408,8 @@ def serialize(
|
||||
|
||||
def _dict_to_dataclass(cls, data):
|
||||
assert not isinstance(cls, str), f"Unresolved class type: '{cls}'."
|
||||
if typing.get_origin(cls) == Annotated:
|
||||
return _dict_to_dataclass(cls.__origin__, data)
|
||||
if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls):
|
||||
if data is None:
|
||||
return None
|
||||
|
Reference in New Issue
Block a user