pytorch/torch/_export/serde/schema.py
Zhengxu Chen 85c807b3fd [export] Ensure optional fields always have default value. (#121163)
Summary: Add additional check to make sure we can always unset an optional field.

Test Plan: CI

Differential Revision: D54504243

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121163
Approved by: https://github.com/tugsbayasgalan
2024-03-05 17:16:49 +00:00


# NOTE: This is a placeholder for iterating on export serialization schema design.
#       Anything is subject to change and no guarantee is provided at this point.

from dataclasses import dataclass, field
from enum import IntEnum
from typing import Dict, List, Optional, Tuple

from torch._export.serde.union import _Union

# NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (4, 2)
TREESPEC_VERSION = 1
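# For reference (an assumption about how serialization consumes this tuple, not
# something this file enforces): (4, 2) would correspond to
# SchemaVersion(major=4, minor=2), defined at the bottom of this file.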

class ScalarType(IntEnum):
    UNKNOWN = 0
    BYTE = 1
    CHAR = 2
    SHORT = 3
    INT = 4
    LONG = 5
    HALF = 6
    FLOAT = 7
    DOUBLE = 8
    COMPLEXHALF = 9
    COMPLEXFLOAT = 10
    COMPLEXDOUBLE = 11
    BOOL = 12
    BFLOAT16 = 13


class Layout(IntEnum):
    Unknown = 0
    SparseCoo = 1
    SparseCsr = 2
    SparseCsc = 3
    SparseBsr = 4
    SparseBsc = 5
    _mkldnn = 6
    Strided = 7


class MemoryFormat(IntEnum):
    Unknown = 0
    ContiguousFormat = 1
    ChannelsLast = 2
    ChannelsLast3d = 3
    PreserveFormat = 4


@dataclass
class Device:
    type: str
    index: Optional[int] = None


@dataclass(repr=False)
class SymExprHint(_Union):
    as_int: int
    as_float: float
    as_bool: bool


# This is for storing the symbolic expressions behind symints/symfloats/symbools
# For example, we can get something like
# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
# if we also have the hint that s0 and s1 are both 2.
@dataclass
class SymExpr:
    expr_str: str
    hint: Optional[SymExprHint] = None
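
# Illustrative sketch, not part of the schema: assuming the _Union.create(**kwargs)
# helper from torch._export.serde.union, which sets exactly one variant, the
# expression above could be built as:
#
#   hint = SymExprHint.create(as_int=4)
#   expr = SymExpr(expr_str="s0 + s1", hint=hint)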


@dataclass(repr=False)
class SymInt(_Union):
    as_expr: SymExpr
    as_int: int


@dataclass(repr=False)
class SymBool(_Union):
    as_expr: SymExpr
    as_bool: bool


@dataclass
class TensorMeta:
    dtype: ScalarType
    sizes: List[SymInt]
    requires_grad: bool
    device: Device
    strides: List[SymInt]
    storage_offset: int
    layout: Layout


# In most cases we will use the "as_name" field to store arguments which are
# SymInts.
# The "as_int" field is used in the case where we have a list containing a mix
# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to
# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
# to the "as_int" field.
@dataclass(repr=False)
class SymIntArgument(_Union):
    as_name: str
    as_int: int
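
# Illustrative sketch, not part of the schema: the mixed list [1, s0] described
# above would serialize to
#
#   [SymIntArgument.create(as_int=1), SymIntArgument.create(as_name="s0")]
#
# again assuming the _Union.create(**kwargs) helper, which sets exactly one variant.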


# In most cases we will use the "as_name" field to store arguments which are
# SymBools.
# The "as_bool" field is used in the case where we have a list containing a mix
# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to
# be List[SymBoolArgument] and map the SymBools to the "as_name" field, and bools
# to the "as_bool" field.
@dataclass(repr=False)
class SymBoolArgument(_Union):
    as_name: str
    as_bool: bool


@dataclass
class TensorArgument:
    name: str


# This is used for storing the contents of a list which contains optional tensors
# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
# type List[OptionalTensorArgument], with tensor values serialized to the
# "as_tensor" field, and None values serialized to the "as_none" field.
@dataclass(repr=False)
class OptionalTensorArgument(_Union):
    as_tensor: str
    as_none: Tuple[()]
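
# Illustrative sketch, not part of the schema: a Tensor?[] value [t, None],
# where t is a tensor named "arg0", would serialize to
#
#   [OptionalTensorArgument.create(as_tensor="arg0"),
#    OptionalTensorArgument.create(as_none=())]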


@dataclass
class GraphArgument:
    name: str
    graph: 'Graph'


@dataclass
class CustomObjArgument:
    name: str
    class_fqn: str


# This is actually a union type
@dataclass(repr=False)
class Argument(_Union):
    as_none: Tuple[()]
    as_tensor: TensorArgument
    as_tensors: List[TensorArgument]
    as_int: int
    as_ints: List[int]
    as_float: float
    as_floats: List[float]
    as_string: str
    as_strings: List[str]
    as_sym_int: SymIntArgument
    as_sym_ints: List[SymIntArgument]
    as_scalar_type: ScalarType
    as_memory_format: MemoryFormat
    as_layout: Layout
    as_device: Device
    as_bool: bool
    as_bools: List[bool]
    as_sym_bool: SymBoolArgument
    as_sym_bools: List[SymBoolArgument]
    as_graph: GraphArgument
    as_optional_tensors: List[OptionalTensorArgument]
    as_custom_obj: CustomObjArgument
    as_operator: str
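
# Illustrative sketch, not part of the schema: a single-tensor argument
# referring to a graph value named "arg0" could be constructed as
#
#   Argument.create(as_tensor=TensorArgument(name="arg0"))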


@dataclass
class NamedArgument:
    # Argument name from the operator schema
    name: str
    arg: Argument


@dataclass
class Node:
    target: str
    inputs: List[NamedArgument]
    outputs: List[Argument]
    metadata: Dict[str, str]


@dataclass
class Graph:
    inputs: List[Argument]
    outputs: List[Argument]
    nodes: List[Node]
    tensor_values: Dict[str, TensorMeta]
    sym_int_values: Dict[str, SymInt]
    sym_bool_values: Dict[str, SymBool]
    # This is for deserializing the submodule graphs from higher order ops
    # (ex. cond, map) where single tensor returns will just return a single
    # tensor, rather than following export schema and returning a singleton
    # list.
    is_single_tensor_return: bool = False
    custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)


@dataclass
class UserInputSpec:
    # Actually, only tensors and SymInts are allowed here
    arg: Argument


@dataclass
class InputToParameterSpec:
    arg: TensorArgument
    parameter_name: str


@dataclass
class InputToBufferSpec:
    arg: TensorArgument
    buffer_name: str
    persistent: bool


@dataclass
class InputToTensorConstantSpec:
    arg: TensorArgument
    tensor_constant_name: str


@dataclass
class InputToCustomObjSpec:
    arg: CustomObjArgument
    custom_obj_name: str


@dataclass(repr=False)
class InputSpec(_Union):
    user_input: UserInputSpec
    parameter: InputToParameterSpec
    buffer: InputToBufferSpec
    tensor_constant: InputToTensorConstantSpec
    custom_obj: InputToCustomObjSpec


@dataclass
class UserOutputSpec:
    arg: Argument


@dataclass
class LossOutputSpec:
    arg: TensorArgument


@dataclass
class BufferMutationSpec:
    arg: TensorArgument
    buffer_name: str


@dataclass
class GradientToParameterSpec:
    arg: TensorArgument
    parameter_name: str


@dataclass
class GradientToUserInputSpec:
    arg: TensorArgument
    user_input_name: str


@dataclass
class UserInputMutationSpec:
    arg: TensorArgument
    user_input_name: str


@dataclass(repr=False)
class OutputSpec(_Union):
    user_output: UserOutputSpec
    loss_output: LossOutputSpec
    buffer_mutation: BufferMutationSpec
    gradient_to_parameter: GradientToParameterSpec
    gradient_to_user_input: GradientToUserInputSpec
    user_input_mutation: UserInputMutationSpec


@dataclass
class GraphSignature:
    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]


@dataclass
class RangeConstraint:
    min_val: int
    max_val: int


@dataclass
class ModuleCallSignature:
    inputs: List[Argument]
    outputs: List[Argument]
    # These are serialized by calling pytree.treespec_dumps
    # and deserialized by calling pytree.treespec_loads
    in_spec: str
    out_spec: str
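
# Illustrative sketch, not part of the schema: round-tripping a spec string
# with the pytree helpers referenced above (x stands in for an example input):
#
#   import torch.utils._pytree as pytree
#   _, spec = pytree.tree_flatten(((x,), {}))          # flatten example inputs
#   in_spec_str = pytree.treespec_dumps(spec)          # str stored in `in_spec`
#   spec_again = pytree.treespec_loads(in_spec_str)    # recover the TreeSpec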


@dataclass
class ModuleCallEntry:
    fqn: str
    signature: Optional[ModuleCallSignature] = None


@dataclass
class GraphModule:
    graph: Graph
    signature: GraphSignature
    # This is used for unflattening, by tracking the calling structure of all of
    # the modules in order to unflatten the modules back to the eager calling
    # conventions.
    module_call_graph: List[ModuleCallEntry]


# Invariant: Every time a change is made to the schema, one of the versions
# should be updated.
@dataclass
class SchemaVersion:
    major: int  # Major version number is bumped every time a breaking change is made.
    minor: int  # Minor version number is bumped when a compatible change is made.


@dataclass
class ExportedProgram:
    graph_module: GraphModule
    # Key is the opset namespace (ex. aten), and value is the version number
    opset_version: Dict[str, int]
    range_constraints: Dict[str, RangeConstraint]
    schema_version: SchemaVersion
    dialect: str