Files
pytorch/torch/_export/serde/serialize.py
Maggie Moss 4ab847bbc7 Pyrefly suppressions 4/n (#164615)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1

after:

0 errors (2,753 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615
Approved by: https://github.com/oulgen
2025-10-06 16:14:36 +00:00

3810 lines
150 KiB
Python

# mypy: allow-untyped-defs
import base64
import copy
import copyreg
import dataclasses
import heapq
import inspect
import io
import json
import keyword
import logging
import math
import operator
import traceback
import typing
from collections import namedtuple, OrderedDict
from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Annotated, Any, cast, final, Optional, Union
import sympy
import torch
import torch.export.exported_program as ep
from torch._export.non_strict_utils import _enable_graph_inputs_of_type_nn_module
from torch._export.verifier import load_verifier
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx._symbolic_trace import _ConstantAttributeType
from torch.fx.experimental import symbolic_shapes
from torch.utils import _pytree as pytree
from torch.utils._pytree import treespec_dumps, treespec_loads
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.symbol import prefix_str, SymT
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._traceback import CapturedTraceback
from torch.utils._triton import has_triton
from ..utils import remove_proxy_from_state_dict
from .schema import ( # type: ignore[attr-defined]
Argument,
ArgumentKind,
BufferMutationSpec,
ComplexValue,
ConstantValue,
CustomObjArgument,
Device,
ExportedProgram,
GradientToParameterSpec,
GradientToUserInputSpec,
Graph,
GraphArgument,
GraphModule,
GraphSignature,
InputSpec,
InputToBufferSpec,
InputToConstantInputSpec,
InputToCustomObjSpec,
InputTokenSpec,
InputToParameterSpec,
InputToTensorConstantSpec,
Layout,
LossOutputSpec,
MemoryFormat,
ModuleCallEntry,
ModuleCallSignature,
NamedArgument,
NamedTupleDef,
Node,
OptionalTensorArgument,
OutputSpec,
OutputTokenSpec,
ParameterMutationSpec,
RangeConstraint,
ScalarType,
SCHEMA_VERSION,
SchemaVersion,
SymBool,
SymBoolArgument,
SymExpr,
SymExprHint,
SymFloat,
SymFloatArgument,
SymInt,
SymIntArgument,
TensorArgument,
TensorMeta,
TokenArgument,
TREESPEC_VERSION,
UserInputMutationSpec,
UserInputSpec,
UserOutputSpec,
)
from .union import _Union
__all__ = [
"serialize",
"GraphModuleSerializer",
"ExportedProgramSerializer",
"GraphModuleDeserializer",
"ExportedProgramDeserializer",
]
log = logging.getLogger(__name__)
class SerializeError(RuntimeError):
pass
def _reverse_map(d: dict[Any, Enum]):
return {v.value: k for k, v in d.items()}
MetaType = Union[
FakeTensor,
int,
torch.SymInt,
float,
torch.SymFloat,
bool,
torch.SymBool,
ep.CustomObjArgument,
]
DEFAULT_PICKLE_PROTOCOL = 2
ST_DELIMITER = ";"
_TORCH_TO_SERIALIZE_DTYPE = {
torch.uint8: ScalarType.BYTE,
torch.int8: ScalarType.CHAR,
torch.uint16: ScalarType.UINT16,
torch.int16: ScalarType.SHORT,
torch.int32: ScalarType.INT,
torch.int64: ScalarType.LONG,
torch.float16: ScalarType.HALF,
torch.float32: ScalarType.FLOAT,
torch.float64: ScalarType.DOUBLE,
torch.complex32: ScalarType.COMPLEXHALF,
torch.complex64: ScalarType.COMPLEXFLOAT,
torch.complex128: ScalarType.COMPLEXDOUBLE,
torch.bool: ScalarType.BOOL,
torch.bfloat16: ScalarType.BFLOAT16,
torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN,
torch.float8_e5m2: ScalarType.FLOAT8E5M2,
torch.float8_e4m3fnuz: ScalarType.FLOAT8E4M3FNUZ,
torch.float8_e5m2fnuz: ScalarType.FLOAT8E5M2FNUZ,
}
_SERIALIZE_TO_TORCH_DTYPE = _reverse_map(_TORCH_TO_SERIALIZE_DTYPE) # type: ignore[arg-type]
_TORCH_TO_SERIALIZE_LAYOUT = {
torch.sparse_coo: Layout.SparseCoo,
torch.sparse_csr: Layout.SparseCsr,
torch.sparse_csc: Layout.SparseCsc,
torch.sparse_bsr: Layout.SparseBsr,
torch.sparse_bsc: Layout.SparseBsc,
torch._mkldnn: Layout._mkldnn, # type: ignore[attr-defined]
torch.strided: Layout.Strided,
}
_SERIALIZE_TO_TORCH_LAYOUT = _reverse_map(_TORCH_TO_SERIALIZE_LAYOUT) # type: ignore[arg-type]
_TORCH_TO_SERIALIZE_MEMORY_FORMAT = {
torch.contiguous_format: MemoryFormat.ContiguousFormat,
torch.channels_last: MemoryFormat.ChannelsLast,
torch.channels_last_3d: MemoryFormat.ChannelsLast3d,
torch.preserve_format: MemoryFormat.PreserveFormat,
}
_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT) # type: ignore[arg-type]
_SYM_OPS = {
operator.eq,
operator.ne,
operator.le,
operator.ge,
operator.lt,
operator.gt,
operator.neg,
operator.pos,
operator.and_,
operator.or_,
math.trunc,
torch.sym_not,
operator.mul,
operator.add,
operator.sub,
operator.floordiv,
operator.mod,
operator.pow,
torch.sym_int,
torch.sym_float,
torch.sym_ite,
torch.sym_max,
torch.sym_min,
torch.sym_sqrt,
operator.truediv,
operator.and_,
}
assert not any(isinstance(op, torch._ops.OpOverload) for op in _SYM_OPS)
@dataclass
class SerializedArtifact:
exported_program: bytes
state_dict: bytes
constants: bytes
example_inputs: bytes
@dataclass
class _SerializedProgram:
exported_program: ExportedProgram
state_dict: bytes
constants: bytes
example_inputs: bytes
class LazyMap(dict):
"""
Dictionary class for deferred instantiation of node metadata values.
Purpose is to avoid creation of symbolic-shape tensors before relevant shape guards are parsed.
"""
def __init__(self):
self.map = {}
self.evaluated = set()
def __setitem__(self, k, v):
self.map[k] = v
def __getitem__(self, k):
out = self.map[k]
if k in self.evaluated:
return out
self.evaluated.add(k)
self.map[k] = out()
return self.map[k]
def __repr__(self):
return self.map.__repr__()
def deserialize_device(d: Device) -> torch.device:
if d.index is None:
return torch.device(type=d.type) # type: ignore[call-overload]
return torch.device(type=d.type, index=d.index)
def deserialize_size(sizes: Sequence[SymInt]) -> tuple[int, ...]:
for sym_int_size in sizes:
assert sym_int_size.type == "as_int", (
f"Only as_int is supported, got {sym_int_size.type}"
)
return tuple(sym_int_size.as_int for sym_int_size in sizes)
def deserialize_stride(strides: Sequence[SymInt]) -> tuple[int, ...]:
for sym_int_stride in strides:
assert sym_int_stride.type == "as_int", (
f"Only as_int is supported, got {sym_int_stride.type}"
)
return tuple(sym_int_stride.as_int for sym_int_stride in strides)
def deserialize_scalar_type(st: ScalarType) -> torch.dtype:
return _SERIALIZE_TO_TORCH_DTYPE[st]
def deserialize_storage_offset(offset: SymInt) -> int:
assert offset.type == "as_int", f"Only as_int is supported, got {offset.type}"
return offset.as_int
def _print_sympy(s: Union[torch.SymInt, torch.SymBool, torch.SymFloat, sympy.Expr]):
if isinstance(s, (torch.SymInt, torch.SymBool, torch.SymFloat)):
s = s.node.expr
return sympy.printing.repr.srepr(s)
def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt:
if isinstance(s, (torch.SymInt, sympy.Symbol, int)):
if symbolic_shapes.is_concrete_int(s):
return SymInt.create(as_int=int(s))
else:
assert isinstance(s, (torch.SymInt, sympy.Symbol))
if s.node.hint is None:
return SymInt.create(as_expr=SymExpr(_print_sympy(s)))
else:
return SymInt.create(
as_expr=SymExpr(
_print_sympy(s),
hint=SymExprHint.create(as_int=s.node.hint),
)
)
else:
raise SerializeError(
f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`"
)
def serialize_sym_float(s: Union[float, torch.SymFloat]) -> SymFloat:
if isinstance(s, (torch.SymFloat, sympy.Symbol, float)):
if symbolic_shapes.is_concrete_float(s):
return SymFloat.create(as_float=float(s))
else:
assert isinstance(s, (torch.SymFloat, sympy.Symbol))
if s.node.hint is None:
return SymFloat.create(as_expr=SymExpr(_print_sympy(s)))
else:
return SymFloat.create(
as_expr=SymExpr(
_print_sympy(s),
hint=SymExprHint.create(as_float=s.node.hint),
)
)
else:
raise SerializeError(
f"SymFloat should be either symbol or float, got `{s}` of type `{type(s)}`"
)
def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool:
if isinstance(s, (torch.SymBool, bool)):
if symbolic_shapes.is_concrete_bool(s):
return SymBool.create(as_bool=bool(s))
else:
return SymBool.create(as_expr=SymExpr(expr_str=_print_sympy(s)))
else:
raise SerializeError(
f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`"
)
def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta:
"""
Extract a TensorMeta describing `t`.
"""
return TensorMeta(
dtype=_TORCH_TO_SERIALIZE_DTYPE[t.dtype],
sizes=[serialize_sym_int(s) for s in t.shape],
requires_grad=t.requires_grad,
device=Device(type=t.device.type, index=t.device.index),
strides=[serialize_sym_int(s) for s in t.stride()],
storage_offset=serialize_sym_int(t.storage_offset()),
layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout],
)
_CURRENT_DESERIALIZER: Optional["GraphModuleDeserializer"] = None
def _reduce_fake_tensor(fake_tensor: FakeTensor):
is_parameter = isinstance(fake_tensor, torch.nn.Parameter)
tensor_meta = serialize_tensor_meta(fake_tensor)
tensor_meta_bytes = json.dumps(
_dataclass_to_dict(tensor_meta), cls=EnumEncoder
).encode("utf-8")
return _reconstruct_fake_tensor, (tensor_meta_bytes, is_parameter)
def _reconstruct_fake_tensor(
serialized_tensor_meta: bytes, is_parameter: bool
) -> FakeTensor:
# Deserialize the bytes into a TensorMeta
json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8"))
tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta)
# Find the current fake mode
assert _CURRENT_DESERIALIZER is not None, (
"Need access to current deserializer state"
)
fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta)
if is_parameter:
fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment]
# pyrefly: ignore # bad-return
return fake_tensor
def serialize_torch_artifact(
artifact: Optional[Any], pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL
) -> bytes:
if artifact is None:
return b""
assert FakeTensor not in copyreg.dispatch_table, (
"Refusing to stomp on existing FakeTensor reducer"
)
try:
copyreg.pickle(FakeTensor, _reduce_fake_tensor)
buffer = io.BytesIO()
# This is a workaround for backend's tensor deserialization problem:
# unpickleTensor() always create a tensor on the device where it was originally saved
# This behavior is bad for multi-gpu training, as we wish to directly load the tensor
# on the designated device.
# For now, we simply move the tensor to cpu before saving.
# TODO: this should be fixed by deserialization instead.
torch.save(artifact, buffer, pickle_protocol=pickle_protocol)
return buffer.getvalue()
finally:
del copyreg.dispatch_table[FakeTensor]
def deserialize_torch_artifact(
serialized: Union[dict[str, Any], tuple[Any, ...], bytes],
):
if isinstance(serialized, (dict, tuple)):
return serialized
if len(serialized) == 0:
return {}
buffer = io.BytesIO(serialized)
buffer.seek(0)
# weights_only=False as we want to load custom objects here (e.g. ScriptObject)
artifact = torch.load(buffer, weights_only=False)
assert isinstance(artifact, (tuple, dict))
return artifact
def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]:
# Convert simple sympy Integers into concrete int
if val in (sympy.oo, int_oo):
return None
if val in (-sympy.oo, -int_oo):
return None
if isinstance(val, sympy.Integer):
return int(val)
# TODO: Remove this adjustment when Ed gets rid of fractional ranges
log.warning(
"Export constraints cannot be non-integer expressions. Found "
"type %s, and value %s. We will attempt to %s "
"this value.",
type(val),
val,
adjust,
)
if adjust == "floor":
return math.floor(val)
elif adjust == "ceil":
return math.ceil(val)
else:
raise RuntimeError(f"Got invalid adjustment {adjust}")
def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr:
# Convert concrete int into simple sympy Integers
if val is None:
return default
if val in [-int_oo, int_oo]:
return val
if val == math.inf:
return int_oo
if val == -math.inf:
return -int_oo
return sympy.Integer(val)
def _symbol_index(sym: sympy.Symbol, sym_type: SymT):
return int(str(sym)[len(prefix_str[sym_type]) :])
def serialize_range_constraints(
range_constraints: dict[sympy.Symbol, ValueRanges],
) -> dict[str, RangeConstraint]:
return {
str(k): RangeConstraint(
_sympy_int_to_int(v.lower, "ceil"), # type: ignore[arg-type]
_sympy_int_to_int(v.upper, "floor"), # type: ignore[arg-type]
)
for k, v in range_constraints.items()
}
def _get_schema_from_target(target):
if isinstance(target, torch._ops.OpOverload):
return target._schema
elif type(target) in _serialization_registry:
return _serialization_registry[type(target)].op_schema(target)
raise RuntimeError(f"Cannot find schema for {type(target)}")
@dataclass
class GraphState:
inputs: list[Argument] = field(default_factory=list)
outputs: list[Argument] = field(default_factory=list)
nodes: list[Node] = field(default_factory=list)
tensor_values: dict[str, TensorMeta] = field(default_factory=dict)
sym_int_values: dict[str, SymInt] = field(default_factory=dict)
sym_bool_values: dict[str, SymBool] = field(default_factory=dict)
sym_float_values: dict[str, SymFloat] = field(default_factory=dict)
is_single_tensor_return: bool = False
custom_obj_values: dict[str, CustomObjArgument] = field(default_factory=dict)
class Final(type):
def __new__(metacls, name, bases, classdict):
for b in bases:
if isinstance(b, Final):
raise TypeError(f"type '{b.__name__}' is not an acceptable base type")
return type.__new__(metacls, name, bases, dict(classdict))
def get_triton_kernel_and_cache_entry(node: torch.fx.Node):
assert (
node.target
is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional
)
assert has_triton(), "triton required to serialize triton kernels"
from triton.runtime.autotuner import Autotuner
assert isinstance(node.kwargs["kernel_idx"], int)
kernel = torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.get_kernel(
node.kwargs["kernel_idx"]
)
kNumWarpsDefault = 4
# currently we only support specialization of
# num_warps -- so search for the entry that
# matches the value from the associated kernel
if isinstance(kernel, Autotuner):
assert len(kernel.configs) == 1
num_warps = kernel.configs[0].num_warps
assert kernel.configs[0].num_ctas == 1, (
"serialization only supports num_ctas == 1"
)
kernel = kernel.fn
else:
num_warps = kNumWarpsDefault
if hasattr(kernel, "device_caches"):
caches = kernel.device_caches
assert len(caches.keys()) == 1
cache = next(iter(caches.values()))[0]
elif hasattr(kernel, "cache"):
# old path, still used for cpu triton builds
caches = kernel.cache
assert len(caches.keys()) == 1
cache = next(iter(caches.values()))
else:
raise AssertionError(f"kernel caches not found for kernel {kernel.__name__}")
# can also get num_warps, num_ctas, etc. from here ig
if len(cache.keys()) == 1:
return kernel, next(iter(cache.values()))
else:
for cache_entry in cache.values():
if cache_entry.metadata.num_warps == num_warps:
return kernel, cache_entry
raise AssertionError(
f"couldn't find a kernel cache entry with metadata matching the autotuner configs for kernel {kernel.__name__}"
)
@final
class GraphModuleSerializer(metaclass=Final):
def __init__(
self,
graph_signature: ep.ExportGraphSignature,
module_call_graph: list[ep.ModuleCallEntry],
):
self.graph_state = GraphState()
self.graph_signature = graph_signature
self.module_call_graph = module_call_graph
self.custom_objs: dict[str, torch._C.ScriptObject] = {}
self.duplicate_getitem_nodes: dict[str, str] = {}
self.treespec_namedtuple_fields: dict[str, NamedTupleDef] = {}
@contextmanager
def save_graph_state(self):
saved = self.graph_state
self.graph_state = GraphState()
try:
yield
finally:
self.graph_state = saved
def handle_placeholder(self, node: torch.fx.Node):
assert node.op == "placeholder"
val = node.meta["val"]
log.debug("[handle_placeholder] %s: %s", node.name, val)
if isinstance(val, torch.Tensor):
graph_input = Argument.create(
as_tensor=self.serialize_tensor_output(node.name, val)
)
elif isinstance(val, torch.SymInt):
graph_input = Argument.create(
as_sym_int=self.serialize_sym_int_output(node.name, val)
)
elif isinstance(val, torch.SymFloat):
raise AssertionError("SymFloat graph input is not implemented yet.")
elif isinstance(val, (int, bool, str, float, type(None))):
graph_input = self.serialize_input(val)
elif isinstance(val, ep.CustomObjArgument):
class_fqn = val.class_fqn
graph_input = Argument.create(
as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn)
)
self.graph_state.custom_obj_values[node.name] = (
self.serialize_script_obj_meta(val)
)
else:
raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}")
self.graph_state.inputs.append(graph_input)
def handle_output(self, node: torch.fx.Node):
assert node.op == "output"
assert len(node.args) == 1, "FX.Node's args should have one arg"
node_args = node.args[0]
log.debug("[handle_output] %s: %s", node.name, node_args)
if isinstance(node_args, torch.fx.Node):
# For singleton tensor returns
self.graph_state.is_single_tensor_return = True
self.graph_state.outputs = [self.serialize_input(node_args)]
else:
assert isinstance(node_args, (tuple, list))
self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args]
def serialize_operator(self, target) -> str:
if isinstance(target, str):
return target
elif target.__module__.startswith("torch._ops"):
# TODO(zhxchen17) Maybe provide a function name helper in FX.
# From torch.fx.node._get_qualified_name
module = target.__module__.replace("torch._ops", "torch.ops")
return f"{module}.{target.__name__}"
else: # TODO(zhxchen17) Don't catch all here.
return f"{target.__module__}.{target.__name__}"
def handle_call_function(self, node: torch.fx.Node):
assert node.op == "call_function"
meta_val = node.meta.get("val")
log.debug(
"[handle_call_function] %s: %s(%s, {%s}) -> %s",
node.name,
node.target,
node.args,
node.kwargs,
meta_val,
)
# getitem has been handled in the producer node, skip it here
if node.target is operator.getitem:
return
if node.target in _SYM_OPS or (
meta_val is not None
and isinstance(meta_val, (torch.SymInt, torch.SymBool, torch.SymFloat))
):
assert len(node.kwargs) == 0
ex_node = Node(
target=self.serialize_operator(node.target),
inputs=self.serialize_sym_op_inputs(node.target, node.args),
outputs=[self.serialize_output(node.name, meta_val)],
metadata=self.serialize_metadata(node),
)
elif isinstance(node.target, torch._ops.OpOverload):
ex_node = Node(
target=self.serialize_operator(node.target),
inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
outputs=self.serialize_outputs(node),
# TODO: create a new tensor_values here, meta might have faketensor info
metadata=self.serialize_metadata(node),
)
elif isinstance(node.target, torch._ops.HigherOrderOperator):
def _is_hop_single_tensor_return(node) -> bool:
assert isinstance(node.target, torch._ops.HigherOrderOperator)
# HOP schema is not always available, so we look at node.meta["val"]
meta_val = node.meta.get("val", None)
return meta_val is not None and isinstance(meta_val, torch.Tensor)
# Special handle serialization for aoti_call_delegate
if node.target is torch._higher_order_ops.aoti_call_delegate:
serializable_args = list(node.args)
# AOTI lowered module is not serializable, serialize the aoti_path instead
lowered_module_name: str = node.args[0].name # type: ignore[assignment, no-untyped-def, union-attr]
assert hasattr(node.graph.owning_module, lowered_module_name)
lowered_module = getattr(node.graph.owning_module, lowered_module_name) # type: ignore[no-untyped-def]
serializable_args[0] = lowered_module.aoti_path
# AOTI compiled graph module in node.args[0] is stateful, and will fail the verifier check
# Skip serializing original_gm as a workaround
serializable_args[1] = None
serializable_weight_nodes = []
if serializable_args[2] is not None and isinstance(
serializable_args[2], Iterable
):
for weight_node in serializable_args[2]:
# skip passing custom obj into the weight arg as an hack
# The schema of weight input is a list of Tensors.
# Downstream runtime is not actively consuming the weighs arg for anything meaningful.
if isinstance(weight_node, torch.fx.Node) and isinstance(
weight_node.meta.get("val", None), ep.CustomObjArgument
):
continue
serializable_weight_nodes.append(weight_node)
serializable_args[2] = serializable_weight_nodes
def serialize_tensor_list_output(node):
meta_val = node.meta.get("val", None)
tensor_args = []
for idx, meta in enumerate(meta_val):
name = self._output_node_name_at_index(node, idx)
tensor_args.append(self.serialize_tensor_output(name, meta))
return [Argument.create(as_tensors=tensor_args)]
ex_node = Node(
target=self.serialize_operator(node.target),
inputs=self.serialize_hoo_inputs(serializable_args, node.kwargs),
outputs=serialize_tensor_list_output(node),
metadata=self.serialize_metadata(node),
is_hop_single_tensor_return=False,
)
elif (
node.target
is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional
):
kernel, kernel_cache_entry = get_triton_kernel_and_cache_entry(node)
kernel_cache_metadata = kernel_cache_entry.metadata
meta_val = node.meta["val"]
assert isinstance(meta_val, dict)
output_keys = meta_val.keys()
output_indices = []
constexpr_keys = set()
for p in kernel.params:
if p.is_constexpr:
constexpr_keys.add(p.name)
found_constexpr = False
args_new = ()
i = 0
assert isinstance(node.kwargs["kwargs"], dict)
for k, v in node.kwargs["kwargs"].items():
# don't serialize constexpr since they will
# be embedded into the binary and don't
# need to be passed around as attributes
if k in constexpr_keys:
found_constexpr = True
continue
assert not found_constexpr, (
"non-constexpr args found after constexpr arg(s)"
)
if k in output_keys:
output_indices.append(i)
args_new += (v,) # type: ignore[assignment]
i += 1
assert isinstance(node.kwargs["grid"], list)
kwargs_new = {
"name": kernel.fn.__name__,
"grid": node.kwargs["grid"][0],
"output_indices": output_indices,
"num_warps": kernel_cache_metadata.num_warps,
}
if hasattr(kernel_cache_metadata, "shared"):
kwargs_new["shared_memory_bytes"] = kernel_cache_metadata.shared
ex_node = Node(
target=self.serialize_operator(node.target),
inputs=self.serialize_hoo_inputs(args_new, kwargs_new),
outputs=self.serialize_hoo_outputs(node),
metadata=self.serialize_metadata(node),
is_hop_single_tensor_return=_is_hop_single_tensor_return(node),
)
else:
ex_node = Node(
target=self.serialize_operator(node.target),
inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
outputs=self.serialize_hoo_outputs(node),
metadata=self.serialize_metadata(node),
is_hop_single_tensor_return=_is_hop_single_tensor_return(node),
)
elif type(node.target) in _serialization_registry:
# Sanity check for unhandled serialization.
assert type(node.target) in _serialization_registry, (
f"{type(node.target)} is not supported in export serialization."
)
handler = _serialization_registry[type(node.target)]
namespace = handler.namespace()
op_name = handler.to_op_name(node.target)
assert isinstance(namespace, str) and isinstance(op_name, str)
assert ":" not in namespace and ":" not in op_name
ex_node = Node(
target=f"#{namespace}:{op_name}",
inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
outputs=self.serialize_outputs(node),
metadata=self.serialize_metadata(node),
)
else:
raise SerializeError(f"Serializing {node.target} is not supported")
self.graph_state.nodes.append(ex_node)
def handle_get_attr(self, node):
log.debug("[handle_get_attr] %s", node.name)
def _output_node_at_index(self, node, index) -> Optional[torch.fx.Node]:
user_node = None
for user in node.users:
assert user.target is operator.getitem, f"{user} is not a getitem node"
if index == user.args[1]:
if user_node is None:
user_node = user
else:
# We want to deduplicate getitem nodes that are trying to
# index to the same index
self.duplicate_getitem_nodes[user.name] = user_node.name
return user_node
def _output_node_name_at_index(self, node, index) -> str:
user_node = self._output_node_at_index(node, index)
if user_node is None:
return f"{node.name}_unused_{index}"
else:
return user_node.name
def serialize_metadata(self, node: torch.fx.Node) -> dict[str, str]:
ret = {}
if stack_trace := node.meta.get("stack_trace"):
ret["stack_trace"] = stack_trace
if nn_module_stack := node.meta.get("nn_module_stack"):
def export_nn_module_stack(val):
assert isinstance(val, tuple) and len(val) == 2
path, ty = val
assert isinstance(path, str)
assert isinstance(ty, str)
return path + "," + ty
# Serialize to "key,orig_path,type_str"
nn_module_list = [
f"{k},{export_nn_module_stack(v)}" for k, v in nn_module_stack.items()
]
ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list)
if source_fn_st := node.meta.get("source_fn_stack"):
source_fn_list = [
f"{source_fn[0]},{self.serialize_operator(source_fn[1])}"
for source_fn in source_fn_st
]
ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list)
if torch_fn := node.meta.get("torch_fn"):
ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn))
if custom := node.meta.get("custom"):
try:
ret["custom"] = json.dumps(custom)
except Exception as e:
raise SerializeError(
f"Failed to serialize custom metadata for node {node.name} with error {e}"
) from e
return ret
def serialize_script_obj_meta(
self, script_obj_meta: ep.CustomObjArgument
) -> CustomObjArgument:
log.debug("[serialize_script_obj_meta] %s", script_obj_meta)
return CustomObjArgument(
name=script_obj_meta.name,
class_fqn=script_obj_meta.class_fqn,
)
def serialize_sym_op_inputs(self, op, args) -> list[NamedArgument]:
if isinstance(op, torch._ops.OpOverload):
args_names = [arg.name for arg in op._schema.arguments]
else:
assert op in _SYM_OPS
args_names = list(inspect.signature(op).parameters.keys())
serialized_args = []
for args_name, arg in zip(args_names, args):
serialized_args.append(
NamedArgument(
name=args_name,
arg=self.serialize_input(arg),
kind=ArgumentKind.POSITIONAL,
)
)
return serialized_args
def serialize_inputs(
self,
target: Any, # torch._ops.OpOverload and other custom operator types.
args,
kwargs=None,
) -> list[NamedArgument]:
schema = None
serialized_args = []
if isinstance(target, torch._higher_order_ops.torchbind.CallTorchBind):
obj = args[0]
method = args[1]
schema = target.schema(obj, method)
else:
assert isinstance(
target, (torch._ops.OpOverload, *_registered_extension_types())
)
schema = _get_schema_from_target(target)
assert schema is not None
kwargs = kwargs or {}
for i, schema_arg in enumerate(schema.arguments):
if schema_arg.name in kwargs:
serialized_args.append(
NamedArgument(
name=schema_arg.name,
arg=self.serialize_input(
kwargs[schema_arg.name], schema_arg.type
),
kind=ArgumentKind.KEYWORD,
)
)
elif not schema_arg.kwarg_only and i < len(args):
serialized_args.append(
NamedArgument(
name=schema_arg.name,
arg=self.serialize_input(args[i], schema_arg.type),
kind=ArgumentKind.POSITIONAL,
)
)
else:
# We intentionally don't serialize the missing arguments
# with default values
pass
return serialized_args
def serialize_hoo_inputs(self, args, kwargs) -> list[NamedArgument]:
"""
For serializing HOO inputs since HOOs do not have a schema.
"""
inputs = [
NamedArgument(
name="", arg=self.serialize_input(a), kind=ArgumentKind.POSITIONAL
)
for a in args
]
inputs.extend(
[
NamedArgument(
name=name,
arg=self.serialize_input(a),
kind=ArgumentKind.KEYWORD,
)
for name, a in kwargs.items()
]
)
return inputs
def is_inductor_sym_int_arg(self, arg) -> bool:
# This is a special branch for handling SymInt args in inductor's
# ExternalFallbackNode.
# For regular FX graph, SymInt arg should be a fx.Node and should be
# verified with is_sym_int_arg()
return type(arg) is int or isinstance(arg, torch.SymInt)
def is_sym_int_arg(self, arg) -> bool:
return type(arg) is int or (
isinstance(arg, torch.fx.Node)
and arg.name in self.graph_state.sym_int_values
)
def is_sym_float_arg(self, arg) -> bool:
return isinstance(arg, float) or (
isinstance(arg, torch.fx.Node)
and arg.name in self.graph_state.sym_float_values
)
def is_sym_bool_arg(self, arg) -> bool:
return isinstance(arg, bool) or (
isinstance(arg, torch.fx.Node)
and arg.name in self.graph_state.sym_bool_values
)
# should be torch._C.JitType but that annotation is busted
def serialize_input(self, arg, arg_type: Optional[Any] = None) -> Argument:
import torch._inductor.ir as inductor_ir
inductor_tensor_buffers = (
inductor_ir.Buffer,
inductor_ir.ReinterpretView,
)
if isinstance(arg, torch.fx.Node):
if arg.op == "get_attr":
assert isinstance(arg.target, str)
attr = getattr(arg.graph.owning_module, arg.target)
if isinstance(attr, torch.Tensor):
raise SerializeError(
"getattr nodes containing tensors should not appear in the graph"
)
elif isinstance(attr, torch.fx.GraphModule):
with self.save_graph_state():
graph = self.serialize_graph(attr)
return Argument.create(
as_graph=GraphArgument(name=arg.target, graph=graph)
)
elif type(attr).__name__ == "LoweredBackendModule":
# Special handling for executorch_call_delegate HOP
# It's first argument is a LoweredBackendModule, for which we
# serialize name and backend id of the lowered module
module_name = getattr(attr, "module_name", None)
backend_id = getattr(attr, "backend_id", None)
assert module_name is not None, "module_name should not be None"
assert backend_id is not None, "backend_id should not be None"
return Argument.create(as_string=f"{module_name}-{backend_id}")
else:
raise SerializeError(
f"Unsupported getattr attribute {arg.target} with type: {type(attr)}"
)
elif self.is_sym_int_arg(arg):
return Argument.create(
as_sym_int=SymIntArgument.create(as_name=arg.name)
)
elif self.is_sym_float_arg(arg):
return Argument.create(
as_sym_float=SymFloatArgument.create(as_name=arg.name)
)
elif self.is_sym_bool_arg(arg):
return Argument.create(
as_sym_bool=SymBoolArgument.create(as_name=arg.name)
)
elif isinstance(arg.meta["val"], ep.CustomObjArgument):
return Argument.create(
as_custom_obj=CustomObjArgument(
name=arg.name, class_fqn=arg.meta["val"].class_fqn
)
)
elif arg.name in self.duplicate_getitem_nodes:
dedup_name = self.duplicate_getitem_nodes[arg.name]
return Argument.create(as_tensor=TensorArgument(name=dedup_name))
else:
return Argument.create(as_tensor=TensorArgument(name=arg.name))
elif isinstance(arg, inductor_tensor_buffers):
# Other branches are for arguments in fx node.
# This is a special branch for handling buffers (representing tensor arguments)
# for inductor's ExternalFallbackNode
# export_extern_kernel_node() is using this function to serialize arguments
arg_name = arg.get_name()
assert arg_name is not None, "Buffer must have valid name"
return Argument.create(as_tensor=TensorArgument(name=arg_name))
elif isinstance(arg, inductor_ir.TorchBindObject):
# This is a special branch for handling TorchBindObject
# for inductor's ExternalFallbackNode
# export_extern_kernel_node() is using this function to serialize arguments
arg_name = arg.get_name()
assert arg_name is not None, "Buffer must have valid name"
arg_val = arg.get_real_obj()
class_fqn = arg_val._type().qualified_name()
self.custom_objs[arg_name] = arg_val
return Argument.create(as_custom_obj=CustomObjArgument(arg_name, class_fqn))
elif isinstance(arg, torch.SymInt):
# This is a special branch for handling SymInt args in inductor's
# ExternalFallbackNode.
# For regular FX graph, SymInt arg should be a fx.Node with
# self.is_sym_int_arg(arg) being true
return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg)))
elif isinstance(arg, torch.SymFloat):
# This is a special branch for handling SymFloat args in inductor's
# ExternalFallbackNode.
# For regular FX graph, SymInt arg should be a fx.Node with
# self.is_sym_float_arg(arg) being true
return Argument.create(
as_sym_float=SymFloatArgument.create(as_name=str(arg))
)
elif type(arg) is bool:
return Argument.create(as_bool=arg)
elif type(arg) is str:
return Argument.create(as_string=arg)
elif type(arg) is int:
return Argument.create(as_int=arg)
elif type(arg) is float:
return Argument.create(as_float=arg)
elif type(arg) is complex:
return Argument.create(
as_complex=ComplexValue(real=arg.real, imag=arg.imag)
)
elif arg is None:
return Argument.create(as_none=True)
elif isinstance(arg, (list, tuple)):
if len(arg) == 0:
if arg_type is not None:
if isinstance(arg_type, torch.OptionalType):
arg_type = arg_type.getElementType() # type: ignore[assignment]
assert isinstance(arg_type, torch.ListType)
elem_type = arg_type.getElementType()
if isinstance(elem_type, torch.OptionalType):
elem_type = elem_type.getElementType()
if isinstance(elem_type, torch.BoolType):
return Argument.create(as_bools=[])
elif isinstance(elem_type, torch.IntType):
return Argument.create(as_ints=[])
elif isinstance(elem_type, torch.FloatType):
return Argument.create(as_floats=[])
elif isinstance(elem_type, torch.StringType):
return Argument.create(as_strings=[])
elif isinstance(elem_type, torch.TensorType):
return Argument.create(as_tensors=[])
else:
# I believe empty symint lists default to ints, but
# please file an issue if this is not the case
raise SerializeError(f"Empty list with type {elem_type} nyi.")
else:
# We could serialize this by default to a tensor list. This
# is needed in the HOO case
log.warning(
"Unsure how to serialize the given empty list, "
"as we don't know what is the type of this argument. "
"Serializing it as a tensor list by default."
)
return Argument.create(as_tensors=[])
if all(type(a) is bool for a in arg):
return Argument.create(as_bools=list(arg))
elif all(type(a) is int for a in arg):
return Argument.create(as_ints=list(arg))
elif all(type(a) is float for a in arg):
return Argument.create(as_floats=list(arg))
elif all(type(a) is str for a in arg):
return Argument.create(as_strings=list(arg))
elif all(self.is_inductor_sym_int_arg(a) for a in arg):
# This is a special branch for handling SymInt args in inductor's
# ExternalFallbackNode.
# For regular FX graph, SymInt arg should be a fx.Node
values = []
for a in arg:
if isinstance(a, torch.SymInt):
values.append(SymIntArgument.create(as_name=str(a)))
elif type(a) is int:
values.append(SymIntArgument.create(as_int=a))
return Argument.create(as_sym_ints=values)
elif all(isinstance(a, torch.SymFloat) for a in arg):
return Argument.create(
as_sym_floats=[SymFloatArgument.create(as_name=str(a)) for a in arg]
)
elif all(self.is_sym_int_arg(a) for a in arg):
# list of sym_ints
values = []
for a in arg:
if isinstance(a, torch.fx.Node):
values.append(SymIntArgument.create(as_name=a.name))
elif type(a) is int:
values.append(SymIntArgument.create(as_int=a))
return Argument.create(as_sym_ints=values)
elif all(self.is_sym_float_arg(a) for a in arg):
# list of sym_float
values = []
for a in arg:
if isinstance(a, torch.fx.Node):
values.append(SymFloatArgument.create(as_name=a.name))
elif isinstance(a, float):
values.append(SymFloatArgument.create(as_float=a))
return Argument.create(as_sym_floats=values)
elif all(self.is_sym_bool_arg(a) for a in arg):
# list of sym_bools
values = []
for a in arg:
if isinstance(a, torch.fx.Node):
values.append(SymBoolArgument.create(as_name=a.name))
elif isinstance(a, bool):
values.append(SymBoolArgument.create(as_bool=a))
return Argument.create(as_sym_bools=values)
elif all(isinstance(a, torch.fx.Node) for a in arg):
# list of tensors
arguments = []
for a in arg:
if a.op == "get_attr":
raise SerializeError(
"getattr nodes containing tensors should not appear in the graph"
)
arguments.append(TensorArgument(name=a.name))
return Argument.create(as_tensors=arguments)
elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg):
# list of optional tensors
def serialize_optional_tensor_args(a):
if a is None:
return OptionalTensorArgument.create(as_none=True)
elif isinstance(a, torch.fx.Node):
return OptionalTensorArgument.create(
as_tensor=TensorArgument(name=a.name)
)
else:
raise SerializeError(f"Unsupported list/tuple argument: {a}")
return Argument.create(
as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
)
elif all(isinstance(a, inductor_tensor_buffers) for a in arg):
# list of inductor buffers
return Argument.create(
as_tensors=[TensorArgument(name=a.get_name()) for a in arg],
)
elif all(
isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg
):
# list of inductor buffers as optional tensors
def serialize_optional_tensor_args(a):
if a is None:
return OptionalTensorArgument.create(as_none=True)
elif isinstance(a, inductor_tensor_buffers):
return OptionalTensorArgument.create(
as_tensor=TensorArgument(name=a.get_name())
)
else:
raise SerializeError(f"Unsupported list/tuple argument: {a}")
return Argument.create(
as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
)
else:
raise SerializeError(
f"Unsupported list/tuple argument type: {[type(a) for a in arg]}"
)
elif isinstance(arg, torch.dtype):
return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg])
elif isinstance(arg, torch.device):
return Argument.create(as_device=Device(type=arg.type, index=arg.index))
elif isinstance(arg, torch.memory_format):
return Argument.create(
as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg]
)
elif isinstance(arg, torch.layout):
return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg])
elif isinstance(arg, torch._C.ScriptObject):
if not (
arg._has_method("__getstate__") # type: ignore[attr-defined]
and arg._has_method("__setstate__") # type: ignore[attr-defined]
):
raise SerializeError(
f"Unable to serialize custom class {arg}. Please define "
"serialization methods via def_pickle()."
)
# Custom objects through torchind are serializable with pickle,
# through implementing the .def_pickle function. This should result
# in the object containing a __getstate__ and __setstate__
# serialize/deserialize function.
custom_obj_name = f"_custom_obj_{len(self.custom_objs)}"
self.custom_objs[custom_obj_name] = arg
class_fqn = arg._type().qualified_name() # type: ignore[attr-defined]
return Argument.create(
as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn)
)
elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
return Argument.create(as_operator=self.serialize_operator(arg))
else:
raise SerializeError(
f"Unsupported argument type: {type(arg)} with schema arg_type {arg_type}"
)
def serialize_tensor_output(self, name, meta_val) -> TensorArgument:
assert name not in self.graph_state.tensor_values
self.graph_state.tensor_values[name] = serialize_tensor_meta(meta_val)
return TensorArgument(name=name)
def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument:
assert name not in self.graph_state.sym_int_values
self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val)
return SymIntArgument.create(as_name=name)
def serialize_sym_float_output(self, name, meta_val) -> SymFloatArgument:
assert name not in self.graph_state.sym_float_values
self.graph_state.sym_float_values[name] = serialize_sym_float(meta_val)
return SymFloatArgument.create(as_name=name)
def serialize_sym_bool_output(self, name, meta_val) -> SymIntArgument:
assert name not in self.graph_state.sym_bool_values
self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val)
return SymBoolArgument.create(as_name=name)
def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
log.debug("[serialize_input_spec] %s", spec)
if spec.kind == ep.InputKind.USER_INPUT:
if isinstance(spec.arg, ep.ConstantArgument):
if type(spec.arg.value) is int:
constant_spec = ConstantValue.create(as_int=spec.arg.value)
elif type(spec.arg.value) is bool:
constant_spec = ConstantValue.create(as_bool=spec.arg.value)
elif type(spec.arg.value) is str:
constant_spec = ConstantValue.create(as_string=spec.arg.value)
elif type(spec.arg.value) is float:
constant_spec = ConstantValue.create(as_float=spec.arg.value)
elif spec.arg.value is None:
constant_spec = ConstantValue.create(as_none=True)
else:
raise SerializeError(
f"Unhandled constant input {spec.arg.value} to serialize"
)
return InputSpec.create(
constant_input=InputToConstantInputSpec(
name=spec.arg.name, value=constant_spec
)
)
else:
return InputSpec.create(
user_input=UserInputSpec(arg=self.serialize_argument_spec(spec.arg))
)
elif spec.kind == ep.InputKind.PARAMETER:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
return InputSpec.create(
parameter=InputToParameterSpec(
arg=TensorArgument(name=spec.arg.name),
parameter_name=spec.target,
)
)
elif spec.kind == ep.InputKind.BUFFER:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
assert spec.persistent is not None
return InputSpec.create(
buffer=InputToBufferSpec(
arg=TensorArgument(name=spec.arg.name),
buffer_name=spec.target,
persistent=spec.persistent,
)
)
elif spec.kind == ep.InputKind.CONSTANT_TENSOR:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
return InputSpec.create(
tensor_constant=InputToTensorConstantSpec(
arg=TensorArgument(name=spec.arg.name),
tensor_constant_name=spec.target,
)
)
elif spec.kind == ep.InputKind.CUSTOM_OBJ:
assert spec.target is not None
assert isinstance(spec.arg, ep.CustomObjArgument)
return InputSpec.create(
custom_obj=InputToCustomObjSpec(
arg=CustomObjArgument(
name=spec.arg.name, class_fqn=spec.arg.class_fqn
),
custom_obj_name=spec.target,
)
)
elif spec.kind == ep.InputKind.TOKEN:
assert isinstance(spec.arg, ep.TokenArgument)
return InputSpec.create(
token=InputTokenSpec(
arg=TokenArgument(name=spec.arg.name),
)
)
else:
raise AssertionError(f"Unknown argument kind: {spec}")
def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec:
log.debug("[serialize_output_spec] %s", spec)
if spec.kind == ep.OutputKind.USER_OUTPUT:
return OutputSpec.create(
user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg))
)
elif spec.kind == ep.OutputKind.LOSS_OUTPUT:
assert isinstance(spec.arg, ep.TensorArgument)
return OutputSpec.create(
loss_output=LossOutputSpec(arg=TensorArgument(name=spec.arg.name))
)
elif spec.kind == ep.OutputKind.BUFFER_MUTATION:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
return OutputSpec.create(
buffer_mutation=BufferMutationSpec(
arg=TensorArgument(name=spec.arg.name),
buffer_name=spec.target,
)
)
elif spec.kind == ep.OutputKind.PARAMETER_MUTATION:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
return OutputSpec.create(
parameter_mutation=ParameterMutationSpec(
arg=TensorArgument(name=spec.arg.name),
parameter_name=spec.target,
)
)
elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
return OutputSpec.create(
gradient_to_parameter=GradientToParameterSpec(
arg=TensorArgument(name=spec.arg.name),
parameter_name=spec.target,
)
)
elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
return OutputSpec.create(
gradient_to_user_input=GradientToUserInputSpec(
arg=TensorArgument(name=spec.arg.name),
user_input_name=spec.target,
)
)
elif spec.kind == ep.OutputKind.USER_INPUT_MUTATION:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
return OutputSpec.create(
user_input_mutation=UserInputMutationSpec(
arg=TensorArgument(name=spec.arg.name),
user_input_name=spec.target,
)
)
elif spec.kind == ep.OutputKind.TOKEN:
assert isinstance(spec.arg, ep.TokenArgument)
return OutputSpec.create(
token=OutputTokenSpec(
arg=TokenArgument(name=spec.arg.name),
)
)
else:
raise AssertionError(f"Unknown argument kind: {spec}")
def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature:
log.debug("\n[serialize_signature]")
return GraphSignature(
input_specs=[self.serialize_input_spec(s) for s in sig.input_specs],
output_specs=[self.serialize_output_spec(s) for s in sig.output_specs],
)
def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument:
if isinstance(x, ep.TensorArgument):
return Argument.create(as_tensor=TensorArgument(name=x.name))
elif isinstance(x, ep.SymIntArgument):
return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name))
elif isinstance(x, ep.SymFloatArgument):
return Argument.create(as_sym_float=SymFloatArgument.create(as_name=x.name))
elif isinstance(x, ep.ConstantArgument):
return self.serialize_input(x.value)
elif isinstance(x, ep.CustomObjArgument):
return Argument.create(
as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn)
)
else:
raise AssertionError("TODO")
def serialize_treespec(self, treespec):
# We want to additionally save all the field names of the namedtuples in
# case users want to check that the treespec types are equivalent
def store_namedtuple_fields(ts):
if ts.type is None:
return
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[
ts.context
].serialized_type_name
if serialized_type_name in self.treespec_namedtuple_fields:
field_names = self.treespec_namedtuple_fields[
serialized_type_name
].field_names
if field_names != ts.context._fields:
raise SerializeError(
f"The given TreeSpec's namedtuple type {ts.context} "
f"was found to have field names {ts.context._fields} "
f"but somehow previously was found to have field names {field_names}."
)
else:
self.treespec_namedtuple_fields[serialized_type_name] = (
NamedTupleDef(field_names=ts.context._fields)
)
for child in ts.children_specs:
store_namedtuple_fields(child)
serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION)
store_namedtuple_fields(treespec)
return serialized_treespec
def serialize_module_call_signature(
self, module_call_signature: ep.ModuleCallSignature
) -> ModuleCallSignature:
log.debug("[serialize_module_call_signature] %s", module_call_signature)
return ModuleCallSignature(
inputs=[
self.serialize_argument_spec(x) for x in module_call_signature.inputs
],
outputs=[
self.serialize_argument_spec(x) for x in module_call_signature.outputs
],
in_spec=self.serialize_treespec(module_call_signature.in_spec),
out_spec=self.serialize_treespec(module_call_signature.out_spec),
forward_arg_names=names
if (names := module_call_signature.forward_arg_names)
else None,
)
def serialize_module_call_graph(
self, module_call_graph: list[ep.ModuleCallEntry]
) -> list[ModuleCallEntry]:
log.debug("\n[serialize_module_call_graph]")
return [
ModuleCallEntry(
fqn=entry.fqn,
signature=(
self.serialize_module_call_signature(entry.signature)
if entry.signature
else None
),
)
for entry in module_call_graph
]
def serialize_outputs(self, node: torch.fx.Node) -> list[Argument]:
"""For a given node, return the dataclass representing its output values.
[NOTE: Multiple outputs] We handle aggregates differently than FX. For
FX, it looks like:
x = call_function("multiple_return", ...)
element0 = call_function(getitem, x, 0)
foo = call_function("use_output", element0)
We do not want the intermediate `getitem` call, so our serialized thing looks like:
element0, element1, element2 = call_function("multiple_return", ...)
foo = call_function("use_output", element0)
We want names to be consistent across these two schemes, so that we can
mostly reuse the names coming from FX. This function computes a mapping from
the FX representation to our representation, preserving the names.
"""
def _is_single_tensor_list_return(target: Any) -> bool:
schema = _get_schema_from_target(target)
returns = schema.returns
if len(returns) != 1:
return False
return_type = returns[0].real_type
return isinstance(return_type, torch.ListType) and isinstance(
return_type.getElementType(), torch.TensorType
)
assert node.op == "call_function" and isinstance(
node.target, (torch._ops.OpOverload, *_registered_extension_types())
)
schema = _get_schema_from_target(node.target)
returns = schema.returns
if len(returns) == 0:
return []
meta_val = node.meta["val"]
# Check single value return
if _is_single_tensor_list_return(node.target):
# e.g "-> Tensor[]"
tensor_args = []
for idx, meta in enumerate(meta_val):
name = self._output_node_name_at_index(node, idx)
tensor_args.append(self.serialize_tensor_output(name, meta))
return [Argument.create(as_tensors=tensor_args)]
elif len(returns) == 1:
return [self.serialize_output(node.name, meta_val)]
# There are a two possibilities at this point:
# - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)"
# - This operator returns a tuple of mixed of Tensor and Tensors, e.g. "-> (Tensor, Tensor[])"
#
# Either way, start by gathering a list of TensorArguments with the correct names.
# For consistent naming with FX, consult the downstream `getitem` node and
# make sure our outputs have the same name.
output_arguments = []
for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)):
if meta is None:
assert isinstance(
return_schema.real_type, (torch.OptionalType, torch.TensorType)
)
# When the return type is annotated as Tensor type, the op can also return an
# undefined Tensor which will be implicitly converted to None in Python.
output_arguments.append(Argument.create(as_none=True))
elif isinstance(meta, FakeTensor):
assert isinstance(
return_schema.real_type, (torch.OptionalType, torch.TensorType)
)
name = self._output_node_name_at_index(node, idx)
output_arguments.append(self.serialize_output(name, meta))
elif isinstance(meta, list):
# for List[Tensor] return type
assert isinstance(
return_schema.real_type, torch.ListType
) and isinstance(
return_schema.real_type.getElementType(), torch.TensorType
)
user_node = self._output_node_at_index(node, idx)
assert user_node is not None
args = []
for i, m in enumerate(meta):
if m is None:
continue
sub_user_node_name = self._output_node_name_at_index(user_node, i)
args.append(self.serialize_tensor_output(sub_user_node_name, m))
output_arguments.append(Argument.create(as_tensors=args))
elif isinstance(meta, (int, SymInt, float, SymFloat)):
user_node_name = self._output_node_name_at_index(node, idx)
output_arguments.append(self.serialize_output(user_node_name, meta))
else:
raise ValueError(
f"Unhandled output type {type(meta)} from node {node.format_node()}"
)
return output_arguments
def serialize_hoo_outputs(self, node: torch.fx.Node) -> list[Argument]:
"""
For serializing HOO outputs since HOOs do not have a schema.
"""
meta_val = node.meta["val"]
if isinstance(meta_val, tuple):
outputs = []
for i, element_meta_val in enumerate(meta_val):
user_node = self._output_node_at_index(node, i)
if isinstance(element_meta_val, list):
# e.g "-> Tensor[]"
assert user_node is not None
tensors = []
for j, m in enumerate(element_meta_val):
if not isinstance(m, torch.Tensor):
raise SerializeError(
f"Serialize list output with type {type(m)} nyi"
)
name = self._output_node_name_at_index(user_node, j)
tensors.append(self.serialize_tensor_output(name, m))
outputs.append(Argument.create(as_tensors=tensors))
else:
name = (
user_node.name
if user_node is not None
else f"{node.name}_unused_{i}"
)
outputs.append(self.serialize_output(name, element_meta_val))
return outputs
elif isinstance(meta_val, dict):
tensor_args = []
# use the dict key as the idx
for idx, meta in meta_val.items():
if not isinstance(meta, torch.Tensor):
raise SerializeError(
f"Serialize list output with type {type(meta)} nyi"
)
name = self._output_node_name_at_index(node, idx)
tensor_args.append(self.serialize_tensor_output(name, meta))
return [Argument.create(as_tensors=tensor_args)]
else:
return [self.serialize_output(node.name, meta_val)]
def serialize_output(self, name: str, meta_val: Any) -> Argument:
# Check single value return
if meta_val is None:
return Argument.create(as_none=True)
if isinstance(meta_val, torch.Tensor):
# e.g "-> Tensor"
return Argument.create(
as_tensor=self.serialize_tensor_output(name, meta_val)
)
elif isinstance(meta_val, (bool, torch.SymBool)):
# e.g "-> SymBool"
return Argument.create(
as_sym_bool=self.serialize_sym_bool_output(name, meta_val)
)
elif isinstance(meta_val, (int, torch.SymInt)):
# e.g "-> SymInt"
assert not isinstance(meta_val, bool)
return Argument.create(
as_sym_int=self.serialize_sym_int_output(name, meta_val)
)
elif isinstance(meta_val, (float, torch.SymFloat)):
# e.g "-> SymFloat"
return Argument.create(
as_sym_float=self.serialize_sym_float_output(name, meta_val)
)
# list outputs should've been handled earlier
raise SerializeError(f"Unable to serialize output {meta_val}")
def _handle_getitem_users(self, node: torch.fx.Node) -> list[TensorArgument]:
meta_val = node.meta["val"]
idx_to_name = {}
for user in node.users:
assert user.target is operator.getitem, (
f"User node {user} of {node} is incorrect"
)
idx_to_name[user.args[1]] = user.name
for idx, _ in enumerate(meta_val):
# FX does not emit a getitem node for any outputs that are unused.
# However, we need a name for them so that the number of outputs will
# correctly match the schema. Just assign a dummy name.
if idx not in idx_to_name:
idx_to_name[idx] = f"{node.name}_unused_{idx}"
arg_list = []
for i, element_meta_val in enumerate(meta_val):
arg_list.append(
self.serialize_tensor_output(idx_to_name[i], element_meta_val)
)
return arg_list
def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph:
assert isinstance(graph_module, torch.fx.GraphModule)
log.debug(
"[serialize_graph]\n\n%s", graph_module.print_readable(print_output=False)
)
for node in graph_module.graph.nodes:
try:
getattr(self, f"handle_{node.op}")(node)
except Exception as e:
raise SerializeError(
f"Failed serializing node {node} in graph: {node.format_node()}\n Original exception {traceback.format_exc()}"
) from e
return Graph(
inputs=self.graph_state.inputs,
nodes=self.graph_state.nodes,
tensor_values=self.graph_state.tensor_values,
sym_int_values=self.graph_state.sym_int_values,
sym_float_values=self.graph_state.sym_float_values,
sym_bool_values=self.graph_state.sym_bool_values,
custom_obj_values=self.graph_state.custom_obj_values,
outputs=self.graph_state.outputs,
is_single_tensor_return=self.graph_state.is_single_tensor_return,
)
def serialize_graph_module_metadata(self, meta: dict[str, Any]):
ret = {}
if custom := meta.get("custom"):
log.debug("\n[serialize_graph_module_metadata] %s", custom)
try:
ret["custom"] = json.dumps(custom)
except Exception as e:
raise SerializeError(
f"Failed to serialize custom metadata for graph with error {e}"
) from e
return ret
def serialize(self, graph_module: torch.fx.GraphModule) -> GraphModule:
log.debug("\n[serialize]")
graph = self.serialize_graph(graph_module)
return GraphModule(
graph=graph,
signature=self.serialize_signature(self.graph_signature),
module_call_graph=self.serialize_module_call_graph(self.module_call_graph),
metadata=self.serialize_graph_module_metadata(graph_module.meta),
treespec_namedtuple_fields=self.treespec_namedtuple_fields,
)
@final
class ExportedProgramSerializer(metaclass=Final):
def __init__(
self,
opset_version: Optional[dict[str, int]] = None,
pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
):
self.opset_version: dict[str, int] = {}
if opset_version:
self.opset_version.update(opset_version)
if "aten" not in self.opset_version:
self.opset_version["aten"] = torch._C._get_max_operator_version()
self.pickle_protocol = pickle_protocol
def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram:
"""
Args:
exported_program: Exported Program to serialize
"""
exported_program.validate()
gm_serializer = GraphModuleSerializer(
exported_program.graph_signature, exported_program.module_call_graph
)
serialized_graph_module = gm_serializer.serialize(exported_program.graph_module)
serialized_range_constraints = serialize_range_constraints(
exported_program.range_constraints
)
# TODO: Directly serialize exported_program.constants once
# CustomClassHolders get stored in the ExportedProgram rather than in
# the graph
constants: dict[str, Any] = gm_serializer.custom_objs.copy()
for n, t in exported_program.constants.items():
assert n not in constants
constants[n] = t
serialized_ep = ExportedProgram(
graph_module=serialized_graph_module,
opset_version=self.opset_version,
range_constraints=serialized_range_constraints,
schema_version=SchemaVersion(
major=SCHEMA_VERSION[0],
minor=SCHEMA_VERSION[1],
),
verifiers=[v.dialect for v in exported_program.verifiers],
torch_version=torch.__version__,
guards_code=exported_program._guards_code,
)
# Test canonical form is well defined.
canonicalize(serialized_ep, set(constants.keys()))
# Proxy cannot be dumped, so we remove them.
new_state_dict = remove_proxy_from_state_dict(
exported_program.state_dict, in_place=False
)
return _SerializedProgram(
serialized_ep,
serialize_torch_artifact(new_state_dict, self.pickle_protocol),
serialize_torch_artifact(constants, self.pickle_protocol),
serialize_torch_artifact(
exported_program.example_inputs, self.pickle_protocol
),
)
@final
class GraphModuleDeserializer(metaclass=Final):
@dataclasses.dataclass
class Result:
graph_module: torch.fx.GraphModule
signature: ep.ExportGraphSignature
module_call_graph: list[ep.ModuleCallEntry]
names_to_symbols: dict[str, sympy.Symbol]
state_dict: dict[str, Union[torch.Tensor, torch.nn.Parameter]]
constants: dict[str, _ConstantAttributeType]
example_inputs: Optional[tuple[tuple[torch.Tensor, ...], dict[str, Any]]]
def __init__(self) -> None:
self.serialized_name_to_node: dict[str, torch.fx.Node] = {}
self.serialized_name_to_meta: LazyMap = LazyMap() # str -> MetaType
self.graph = torch.fx.Graph()
self.module = torch.nn.Module()
@contextmanager
def save_graph_module(self) -> Iterator[None]:
saved = (
self.graph,
self.module,
self.serialized_name_to_node,
self.serialized_name_to_meta,
self.unbacked_symbols,
)
self.graph = torch.fx.Graph()
self.module = torch.nn.Module()
self.serialized_name_to_node = {}
self.serialized_name_to_meta = LazyMap()
self.unbacked_symbols: set[sympy.Symbol] = set()
try:
yield
finally:
(
self.graph,
self.module,
self.serialized_name_to_node,
self.serialized_name_to_meta,
self.unbacked_symbols,
) = saved
def deserialize_extension_operator(self, serialized_target: str):
namespace, op_name = serialized_target.split(":")
namespace = namespace[1:] # starting with #
handler = _deserialization_registry[namespace]
return handler.from_op_name(op_name)
def deserialize_operator(self, serialized_target: str):
if serialized_target.startswith(
"_operator"
): # TODO(zhxchen17) Follow up on this.
module = operator
serialized_target_names = serialized_target.split(".")[1:]
elif serialized_target.startswith("torch"):
module = torch # type: ignore[misc]
serialized_target_names = serialized_target.split(".")[1:]
elif serialized_target.startswith("math"):
module = math # type: ignore[misc]
serialized_target_names = serialized_target.split(".")[1:]
elif serialized_target.startswith("#"):
return self.deserialize_extension_operator(serialized_target)
else: # TODO(zhxchen17) Don't catch all here.
return serialized_target
target = module
for name in serialized_target_names:
if not hasattr(target, name):
return serialized_target
else:
target = getattr(target, name)
return target
def _parse_sym_expr(
self, expr_str: str, hint: Optional[Union[int, bool, float]] = None
) -> sympy.Expr:
"""
Parses and does bottom-up processing of sympy.Expr nodes,
populating ShapeEnv & caching symbols as needed.
"""
def _process_sym_expr(
sym: sympy.Expr, hint: Optional[Union[int, bool, float]] = None
) -> sympy.Expr:
if sym.is_Integer or sym.is_Float or sym.is_Boolean: # base case
return sym
else: # recursive case
# important to use str(expr) and not _print_sympy(),
# str(expr) is key for self.symbol_name_to_range
expr_str = str(sym)
for arg in sym.args:
self._parse_sym_expr(arg)
# symbol caching
if expr_str in self.symbol_name_to_symbol:
sym = self.symbol_name_to_symbol[expr_str]
else:
self.symbol_name_to_symbol[expr_str] = sym
if isinstance(sym, sympy.Symbol) and symbolic_shapes.symbol_is_type(
sym, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)
):
self.unbacked_symbols.add(sym)
# hints
if hint is not None and sym not in self.shape_env.var_to_val:
self.shape_env.add_var_to_val(sym, hint) # type: ignore[arg-type]
# ValueRanges
if vr := self.symbol_name_to_range.get(expr_str):
self.shape_env.constrain_symbol_range(
sym,
compiler_min=vr.lower, # type: ignore[arg-type]
compiler_max=vr.upper, # type: ignore[arg-type]
)
# ShapeEnv meta
if isinstance(sym, sympy.Symbol):
self.shape_env.var_to_stack[sym] = CapturedTraceback.extract(skip=1)
return sym
expr = sympy.sympify(
expr_str,
locals={**self.sympy_functions, **self.symbol_name_to_symbol},
)
return _process_sym_expr(expr, hint)
def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
val = s.value
if s.type == "as_expr":
if val.hint is None:
hint = None
else:
assert val.hint.type == "as_int"
hint = val.hint.value
sym = self._parse_sym_expr(val.expr_str, hint)
return self.shape_env.create_symintnode(sym, hint=hint)
elif s.type == "as_int":
assert type(val) is int
return val
else:
raise SerializeError(
f"SymInt has invalid field type {s.type} with value {s.value}"
)
def deserialize_sym_float(self, s: SymFloat) -> Union[float, torch.SymFloat]:
val = s.value
if s.type == "as_expr":
hint = val.hint.as_float if val.hint else None
sym = self._parse_sym_expr(val.expr_str, hint)
return self.shape_env.create_symfloatnode(sym, hint=hint)
elif s.type == "as_float":
assert isinstance(val, float)
return val
else:
raise SerializeError(
f"SymFloat has invalid field type {s.type} with value {s.value}"
)
def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]:
val = s.value
if s.type == "as_expr":
expr = self._parse_sym_expr(val.expr_str)
return self.shape_env.create_symboolnode(expr)
elif s.type == "as_bool":
assert isinstance(val, bool)
return val
else:
raise SerializeError(
f"SymBool has invalid field type {s.type} with value {s.value}"
)
def deserialize_tensor_meta(
self,
tensor_meta: TensorMeta,
) -> FakeTensor:
with self.fake_tensor_mode:
return cast(
FakeTensor,
torch.empty_strided(
tuple(self.deserialize_sym_int(val) for val in tensor_meta.sizes), # type: ignore[misc]
tuple(self.deserialize_sym_int(val) for val in tensor_meta.strides), # type: ignore[misc]
device=deserialize_device(tensor_meta.device),
dtype=_SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype],
requires_grad=tensor_meta.requires_grad,
),
)
def deserialize_script_obj_meta(
self, script_obj_meta: CustomObjArgument
) -> ep.CustomObjArgument:
return ep.CustomObjArgument(
name=script_obj_meta.name,
class_fqn=script_obj_meta.class_fqn,
)
def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]:
if output.type == "as_tensor":
return self.serialized_name_to_node[output.as_tensor.name]
elif output.type == "as_sym_int":
return self.serialized_name_to_node[output.as_sym_int.as_name]
elif output.type == "as_sym_bool":
return self.serialized_name_to_node[output.as_sym_bool.as_name]
elif output.type == "as_sym_float":
return self.serialized_name_to_node[output.as_sym_float.as_name]
elif output.type == "as_int":
return output.as_int
elif output.type == "as_float":
return output.as_float
elif output.type == "as_bool":
return output.as_bool
elif output.type == "as_none":
return None
else:
raise SerializeError(f"Unable to deserialize output node {output}")
def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
log.debug("\n[deserialize_graph]")
# Handle the tensor metas.
for name, tensor_value in serialized_graph.tensor_values.items():
log.debug("[deserialize_tensor_meta] %s (input): %s", name, tensor_value)
self.serialized_name_to_meta[name] = (
lambda v=tensor_value: self.deserialize_tensor_meta(v)
)
for name, sym_int_value in serialized_graph.sym_int_values.items():
log.debug("[deserialize_sym_int] %s (input): %s", name, sym_int_value)
self.serialized_name_to_meta[name] = (
lambda v=sym_int_value: self.deserialize_sym_int(v)
)
for name, sym_float_value in serialized_graph.sym_float_values.items():
log.debug("[deserialize_sym_float] %s (input): %s", name, sym_float_value)
self.serialized_name_to_meta[name] = (
lambda v=sym_float_value: self.deserialize_sym_float(v)
)
for name, sym_bool_value in serialized_graph.sym_bool_values.items():
log.debug("[deserialize_sym_bool] %s (input): %s", name, sym_bool_value)
self.serialized_name_to_meta[name] = (
lambda v=sym_bool_value: self.deserialize_sym_bool(v)
)
for name, script_obj_meta in serialized_graph.custom_obj_values.items():
log.debug("[deserialize_script_obj_meta] %s", script_obj_meta)
self.serialized_name_to_meta[name] = (
lambda v=script_obj_meta: self.deserialize_script_obj_meta(v)
)
log.debug("\n[deserialize graph nodes]")
# Inputs: convert to placeholder nodes in FX.
for i, input_ in enumerate(serialized_graph.inputs):
log.debug("[deserialize input] %s", input_)
if input_.type in ("as_tensor", "as_custom_obj"):
node_name = input_.value.name
placeholder_node = self.graph.placeholder(node_name)
# FX might declare a name illegal (e.g. some nn.Modules use "input" as forward() arguments)
# we will overwrite it
placeholder_node.name = node_name
self.sync_fx_node(node_name, placeholder_node)
elif input_.type == "as_sym_int":
if input_.value.type == "as_name":
node_name = input_.value.as_name
placeholder_node = self.graph.placeholder(node_name)
# FX might declare a name illegal (e.g. some nn.Modules use "input" as forward() arguments)
# we will overwrite it
placeholder_node.name = node_name
self.sync_fx_node(node_name, placeholder_node)
else:
raise SerializeError(
f"Deserializing a constant symint {input_.value} as an input"
)
elif input_.type in (
"as_int",
"as_float",
"as_bool",
"as_none",
"as_string",
):
node_name = self.signature.input_specs[i].arg.name or f"arg{i}"
placeholder_node = self.graph.placeholder(node_name)
placeholder_node.meta["val"] = self.deserialize_input(input_)
else:
raise SerializeError(f"Invalid input type {input_}")
# Nodes: convert to call_function nodes.
for serialized_node in serialized_graph.nodes:
try:
target = self.deserialize_operator(serialized_node.target)
self.deserialize_node(serialized_node, target)
except Exception as e:
raise SerializeError(
f"Failed deserializing node {serialized_node}\n Original exception {traceback.format_exc()}"
) from e
# Outputs: convert to a single `output` node.
outputs = []
for output in serialized_graph.outputs:
log.debug("[deserialize output] %s", output)
outputs.append(self.deserialize_graph_output(output))
if serialized_graph.is_single_tensor_return:
assert len(outputs) == 1
outputs = outputs[0] # type: ignore[assignment]
else:
outputs = tuple(outputs) # type: ignore[assignment]
output_node = self.graph.output(outputs)
if serialized_graph.is_single_tensor_return:
output_node.meta["val"] = output_node.args[0].meta["val"]
else:
output_node.meta["val"] = tuple(
arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
for arg in output_node.args[0]
)
# recompute unbacked bindings
for node in self.graph.nodes:
if (val := node.meta.get("val")) is not None and (
unbacked_bindings := symbolic_shapes._free_unbacked_symbols_with_path(
val,
(),
shape_env=self.shape_env,
pending=self.unbacked_symbols,
simplify=True,
)
):
node.meta["unbacked_bindings"] = unbacked_bindings
assert len(self.unbacked_symbols) == 0
return self.graph
def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
def _is_single_tensor_return(target) -> bool:
schema = _get_schema_from_target(target)
returns = schema.returns
return len(returns) == 1 and isinstance(
returns[0].real_type, torch.TensorType
)
if (
target in _SYM_OPS
or target
== torch.ops.aten.item.default # this can produce either SymInt or SymBool
):
name = serialized_node.outputs[0].value.as_name
args = self.deserialize_sym_op_inputs(serialized_node.inputs)
fx_node = self.graph.create_node("call_function", target, args, {}, name)
self.deserialize_sym_op_outputs(serialized_node, fx_node)
elif (
target
is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional
):
raise SerializeError(
"deserialize nyi for torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional"
)
elif isinstance(target, torch._ops.HigherOrderOperator):
args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs)
metadata = self.deserialize_metadata(serialized_node.metadata)
for x in (*args, *kwargs.values()):
if isinstance(x, torch.fx.Node) and x.op == "get_attr":
# this means that we have deserialized a graph argument, but
# unfortunately the schema for it does not include metadata;
# so we reuse the metadata of the HOP call for such arguments
x.meta.update(metadata)
# If a serialized HOP node has a length=1 outputs of type `as_tensor``.
# There could be two cases:
# (1) The HOP node returns a single tensor
# (2) The HOP node returns a tuple containing a single tensor
# We distinguish (1) and (2) by the `is_single_tensor_return`
# field in the schema of Node
# For BC, getattr() will return True if `is_single_tensor_return` doesn't
# exist. This is because prior to adding `is_single_tensor_return`,
# only (1) could happen as we handle (2) with type `as_tensors`
name = (
serialized_node.outputs[0].as_tensor.name
if len(serialized_node.outputs) == 1
and hasattr(serialized_node.outputs[0], "as_tensor")
and getattr(serialized_node, "is_hop_single_tensor_return", True)
else None
)
fx_node = self.graph.create_node(
"call_function", target, args, kwargs, name
)
self.deserialize_outputs(serialized_node, fx_node)
fx_node.meta.update(metadata)
elif isinstance(
target, (torch._ops.OpOverload, *_registered_extension_types())
):
# For convenience: if this node returns a single tensor, name the
# newly-created node after it. This ensures that these tensor values
# have names that are consistent with serialized.
name = (
serialized_node.outputs[0].as_tensor.name
if _is_single_tensor_return(target)
else None # FX will generate a name for us.
)
args, kwargs = self.deserialize_inputs(target, serialized_node)
fx_node = self.graph.create_node(
"call_function", target, args, kwargs, name
)
self.deserialize_outputs(serialized_node, fx_node)
else:
_additional_msg = (
(
f"We failed to resolve {target} to an operator. "
+ "If it's a custom op/custom triton op, this is usually because the custom op is not registered"
+ " when deserializing. Please import the custom op to register it before deserializing."
+ " Otherwise, please file an issue on github."
)
if isinstance(target, str)
else ""
)
raise SerializeError(
_additional_msg
+ f" Unsupported target type for node {serialized_node}: {type(target)}."
)
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
log.debug(
"[deserialize_node] %s: %s(%s, {%s}) -> %s",
fx_node.name,
fx_node.target,
fx_node.args,
fx_node.kwargs,
fx_node.meta.get("val"),
)
# handle ShapeEnv asserts
if target == torch.ops.aten._assert_scalar.default:
if not isinstance((arg := fx_node.args[0]), bool):
expr = arg.meta["val"] # type: ignore[union-attr]
if isinstance(expr, torch.SymBool):
self.shape_env.guard_or_defer_runtime_assert(
expr.node.expr, "", fx_node
)
elif target == torch.ops.aten.sym_constrain_range_for_size.default:
sym = fx_node.args[0].meta["val"] # type: ignore[union-attr]
if isinstance(sym, torch.SymInt):
self.shape_env._constrain_range_for_size(sym.node.expr)
# handle nn_module_stack; serialization throws away empty dicts
if (
fx_node.op not in ["placeholder", "output"]
and "nn_module_stack" not in fx_node.meta
):
fx_node.meta["nn_module_stack"] = {}
def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec:
log.debug("[deserialize_input_spec] %s", i)
if i.type == "user_input":
return ep.InputSpec(
kind=ep.InputKind.USER_INPUT,
arg=self.deserialize_argument_spec(i.user_input.arg),
target=None,
)
elif i.type == "parameter":
return ep.InputSpec(
kind=ep.InputKind.PARAMETER,
arg=ep.TensorArgument(name=i.parameter.arg.name),
target=i.parameter.parameter_name,
)
elif i.type == "buffer":
return ep.InputSpec(
kind=ep.InputKind.BUFFER,
arg=ep.TensorArgument(name=i.buffer.arg.name),
target=i.buffer.buffer_name,
persistent=i.buffer.persistent,
)
elif i.type == "tensor_constant":
return ep.InputSpec(
kind=ep.InputKind.CONSTANT_TENSOR,
arg=ep.TensorArgument(name=i.tensor_constant.arg.name),
target=i.tensor_constant.tensor_constant_name,
)
elif i.type == "custom_obj":
return ep.InputSpec(
kind=ep.InputKind.CUSTOM_OBJ,
arg=ep.CustomObjArgument(
name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn
),
target=i.custom_obj.custom_obj_name,
)
elif i.type == "token":
return ep.InputSpec(
kind=ep.InputKind.TOKEN,
arg=ep.TokenArgument(name=i.token.arg.name),
target=None,
)
elif i.type == "constant_input":
return ep.InputSpec(
kind=ep.InputKind.USER_INPUT,
arg=ep.ConstantArgument(
name=i.constant_input.name,
value=self.deserialize_constant_input(i.constant_input.value),
),
target=None,
)
else:
raise AssertionError(f"Unknown input spec {i}")
def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec:
log.debug("[deserialize_output_spec] %s", o)
if o.type == "user_output":
return ep.OutputSpec(
kind=ep.OutputKind.USER_OUTPUT,
arg=self.deserialize_argument_spec(o.user_output.arg),
target=None,
)
elif o.type == "loss_output":
return ep.OutputSpec(
kind=ep.OutputKind.LOSS_OUTPUT,
arg=ep.TensorArgument(name=o.loss_output.arg.name),
target=None,
)
elif o.type == "buffer_mutation":
return ep.OutputSpec(
kind=ep.OutputKind.BUFFER_MUTATION,
arg=ep.TensorArgument(name=o.buffer_mutation.arg.name),
target=o.buffer_mutation.buffer_name,
)
elif o.type == "parameter_mutation":
return ep.OutputSpec(
kind=ep.OutputKind.PARAMETER_MUTATION,
arg=ep.TensorArgument(name=o.parameter_mutation.arg.name),
target=o.parameter_mutation.parameter_name,
)
elif o.type == "gradient_to_parameter":
return ep.OutputSpec(
kind=ep.OutputKind.GRADIENT_TO_PARAMETER,
arg=ep.TensorArgument(name=o.gradient_to_parameter.arg.name),
target=o.gradient_to_parameter.parameter_name,
)
elif o.type == "gradient_to_user_input":
return ep.OutputSpec(
kind=ep.OutputKind.GRADIENT_TO_USER_INPUT,
arg=ep.TensorArgument(name=o.gradient_to_user_input.arg.name),
target=o.gradient_to_user_input.user_input_name,
)
elif o.type == "user_input_mutation":
return ep.OutputSpec(
kind=ep.OutputKind.USER_INPUT_MUTATION,
arg=ep.TensorArgument(name=o.user_input_mutation.arg.name),
target=o.user_input_mutation.user_input_name,
)
elif o.type == "token":
return ep.OutputSpec(
kind=ep.OutputKind.TOKEN,
arg=ep.TokenArgument(name=o.token.arg.name),
target=None,
)
else:
raise AssertionError(f"Unknown output spec {o}")
def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature:
log.debug("\n[deserialize_signature]")
return ep.ExportGraphSignature(
input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs],
output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs],
)
def deserialize(
self,
serialized_graph_module: GraphModule,
serialized_state_dict: Union[dict[str, torch.Tensor], bytes],
constants: Union[dict[str, Any], bytes],
example_inputs: Optional[
Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes]
] = None,
symbol_name_to_range: Optional[dict[str, symbolic_shapes.ValueRanges]] = None,
) -> Result:
global _CURRENT_DESERIALIZER
assert _CURRENT_DESERIALIZER is None
_CURRENT_DESERIALIZER = self
try:
log.debug("\n[deserialize]")
self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True)
self.fake_tensor_mode = FakeTensorMode(
allow_fallback_kernels=False,
allow_non_fake_inputs=True,
shape_env=self.shape_env,
)
self.sympy_functions = {
# all torch.utils._sympy.functions should go here
# TODO(avik): find a better way to keep this collection in sync;
# e.g.., `exec('from torch.utils._sympy.functions import *', ...)`
# would work as long as the public API of that module is complete
"FloorDiv": torch.utils._sympy.functions.FloorDiv,
"ModularIndexing": torch.utils._sympy.functions.ModularIndexing,
"Where": torch.utils._sympy.functions.Where,
"PythonMod": torch.utils._sympy.functions.PythonMod,
"Mod": torch.utils._sympy.functions.Mod,
"CleanDiv": torch.utils._sympy.functions.CleanDiv,
"CeilToInt": torch.utils._sympy.functions.CeilToInt,
"FloorToInt": torch.utils._sympy.functions.FloorToInt,
"CeilDiv": torch.utils._sympy.functions.CeilDiv,
"LShift": torch.utils._sympy.functions.LShift,
"RShift": torch.utils._sympy.functions.RShift,
"PowByNatural": torch.utils._sympy.functions.PowByNatural,
"FloatPow": torch.utils._sympy.functions.FloatPow,
"FloatTrueDiv": torch.utils._sympy.functions.FloatTrueDiv,
"IntTrueDiv": torch.utils._sympy.functions.IntTrueDiv,
"IsNonOverlappingAndDenseIndicator": torch.utils._sympy.functions.IsNonOverlappingAndDenseIndicator,
"TruncToFloat": torch.utils._sympy.functions.TruncToFloat,
"TruncToInt": torch.utils._sympy.functions.TruncToInt,
"RoundToInt": torch.utils._sympy.functions.RoundToInt,
"RoundDecimal": torch.utils._sympy.functions.RoundDecimal,
"ToFloat": torch.utils._sympy.functions.ToFloat,
"Identity": torch.utils._sympy.functions.Identity,
}
self.symbol_name_to_symbol: dict[str, sympy.Symbol] = {}
self.constants = deserialize_torch_artifact(constants)
self.signature = self.deserialize_signature(
serialized_graph_module.signature
)
# deserialization does analysis with checks on 0/1, so we create fake range constraints and
# restore the original range constraints afterwards
self.symbol_name_to_range = {}
# we also need to bump unbacked sym[float,int] counters in the
# shape env to accommodate unbacked symbols in the exported program
self.unbacked_symbols = set()
count_unbacked_symfloat, count_unbacked_symint = -1, -1
unbacked_symfloat_prefix, unbacked_symint_prefix = (
prefix_str[t] for t in [SymT.UNBACKED_FLOAT, SymT.UNBACKED_INT]
)
if symbol_name_to_range:
for k, vr in symbol_name_to_range.items():
lower = vr.lower
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(
_int_to_sympy_int(lower, -int_oo), vr.upper
)
if k.startswith(unbacked_symfloat_prefix):
i = int(k[len(unbacked_symfloat_prefix) :])
count_unbacked_symfloat = max(count_unbacked_symfloat, i)
elif k.startswith(unbacked_symint_prefix):
i = int(k[len(unbacked_symint_prefix) :])
count_unbacked_symint = max(count_unbacked_symint, i)
# TODO(pianpwk): if we can clean up unused symbols in range_constraints,
# then this logic can just be handled with self.unbacked_symbols alone
for _ in range(count_unbacked_symfloat + 1):
self.shape_env.unbacked_symfloat_counter += 1
for _ in range(count_unbacked_symint + 1):
self.shape_env.unbacked_symint_counter += 1
if example_inputs is not None and len(example_inputs) > 0:
self.example_inputs = deserialize_torch_artifact(example_inputs)
else:
self.example_inputs = None
self.deserialize_graph(serialized_graph_module.graph)
with _enable_graph_inputs_of_type_nn_module(self.example_inputs):
module_call_graph = self.deserialize_module_call_graph(
serialized_graph_module.module_call_graph
)
graph_module = ep._create_graph_module_for_export(self.module, self.graph)
meta = {}
if custom := serialized_graph_module.metadata.get("custom"):
meta["custom"] = json.loads(custom)
if hasattr(serialized_graph_module, "treespec_namedtuple_fields"):
meta["treespec_namedtuple_fields"] = {}
for (
type_,
fields,
) in serialized_graph_module.treespec_namedtuple_fields.items():
meta["treespec_namedtuple_fields"][type_] = fields.field_names
graph_module.meta = meta
return GraphModuleDeserializer.Result(
graph_module=graph_module,
signature=self.signature,
module_call_graph=module_call_graph,
names_to_symbols=self.symbol_name_to_symbol,
state_dict=deserialize_torch_artifact(serialized_state_dict),
constants=self.constants,
example_inputs=self.example_inputs,
)
finally:
_CURRENT_DESERIALIZER = None
def sync_fx_node(self, name: str, fx_node: torch.fx.Node):
if name in self.serialized_name_to_node:
raise SerializeError(f"Node {name} has already been deserialized before.")
# overwrite name
fx_node.name = name
self.serialized_name_to_node[name] = fx_node
assert "val" not in fx_node.meta
fx_node.meta["val"] = self.serialized_name_to_meta[name]
def deserialize_sym_op_inputs(self, inputs):
return tuple(self.deserialize_input(input.arg) for input in inputs)
def deserialize_inputs(self, target, serialized_node: Node):
schema_args = _get_schema_from_target(target).arguments
argument_kinds = {input.name: input.kind for input in serialized_node.inputs}
actual_args = {
input.name: self.deserialize_input(input.arg)
for input in serialized_node.inputs
}
args = []
kwargs: OrderedDict[str, Any] = OrderedDict()
for schema_arg in schema_args:
if schema_arg.name in actual_args:
arg = actual_args[schema_arg.name]
kind = argument_kinds[schema_arg.name]
if kind == ArgumentKind.POSITIONAL:
args.append(arg)
continue
elif kind == ArgumentKind.KEYWORD and not keyword.iskeyword(
schema_arg.name
):
kwargs[schema_arg.name] = arg
continue
# If there's no ArgumentKind found, fallback to the old cases.
is_positional = (
not schema_arg.has_default_value() and not schema_arg.kwarg_only
)
if is_positional:
args.append(actual_args[schema_arg.name])
elif keyword.iskeyword(schema_arg.name):
assert not schema_arg.kwarg_only
if len(kwargs) > 0:
kwargs = OrderedDict()
args.extend(list(kwargs.values()))
args.append(actual_args[schema_arg.name])
else:
if schema_arg.name in actual_args:
kwargs[schema_arg.name] = actual_args[schema_arg.name]
return tuple(args), kwargs
def deserialize_hoo_inputs(self, inputs: list[NamedArgument]):
"""
For deserializing HOO inputs since HOOs do not have a schema.
"""
args = []
kwargs = {}
for input_ in inputs:
if input_.name != "":
kwargs[input_.name] = self.deserialize_input(input_.arg)
else:
args.append(self.deserialize_input(input_.arg))
return (tuple(args), kwargs)
def deserialize_input(self, inp: Argument) -> Any:
value = inp.value
typ_ = inp.type
if typ_ == "as_none":
# None should converted as None, but is encoded as bool in serialized
# Convert serialized object to torch equivalent
return None
elif typ_ == "as_tensor":
return self.serialized_name_to_node[inp.as_tensor.name]
elif typ_ == "as_scalar_type":
return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type]
elif typ_ == "as_memory_format":
return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format]
elif typ_ == "as_layout":
return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout]
elif typ_ == "as_graph":
assert isinstance(value, GraphArgument)
with self.save_graph_module():
self.deserialize_graph(value.graph)
submodule = ep._create_graph_module_for_export(self.module, self.graph)
self.module.register_module(value.name, submodule)
return self.graph.create_node(
"get_attr",
value.name,
name=value.name,
)
elif typ_ == "as_device":
return deserialize_device(inp.as_device)
elif typ_ == "as_int":
return inp.as_int
elif typ_ == "as_float":
return inp.as_float
elif typ_ == "as_bool":
return inp.as_bool
elif typ_ == "as_string":
return inp.as_string
elif typ_ == "as_complex":
return complex(inp.as_complex.real, inp.as_complex.imag)
elif typ_ == "as_sym_int":
return self.deserialize_sym_argument(inp.as_sym_int)
elif typ_ == "as_sym_float":
return self.deserialize_sym_argument(inp.as_sym_float)
elif typ_ == "as_sym_bool":
return self.deserialize_sym_argument(inp.as_sym_bool)
elif isinstance(value, list):
if len(value) == 0:
return []
elif typ_ == "as_tensors":
result = [self.serialized_name_to_node[arg.name] for arg in value]
return result
elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"):
# convert from serialized.python.types.List to python list
return list(value)
elif typ_ in ("as_sym_ints", "as_sym_bools", "as_sym_floats"):
return [self.deserialize_sym_argument(arg) for arg in value]
elif typ_ == "as_optional_tensors":
def deserialize_optional_tensor_args(a):
if a.type == "as_none":
return None
elif a.type == "as_tensor":
return self.serialized_name_to_node[a.value.name]
else:
raise SerializeError(f"Unhandled argument {inp}")
return list(map(deserialize_optional_tensor_args, value))
else:
raise SerializeError(f"Unhandled argument {inp}")
elif typ_ == "as_custom_obj":
if inp.as_custom_obj.name in self.serialized_name_to_node:
# Custom object has been lifted as an input
return self.serialized_name_to_node[inp.as_custom_obj.name]
return self.constants[inp.as_custom_obj.name]
elif typ_ == "as_operator":
return self.deserialize_operator(inp.as_operator)
else:
raise SerializeError(f"Unhandled argument {inp}")
def deserialize_constant_input(self, inp: ConstantValue) -> Any:
if inp.type == "as_int":
return int(inp.as_int)
elif inp.type == "as_float":
return float(inp.as_float)
elif inp.type == "as_string":
return str(inp.as_string)
elif inp.type == "as_bool":
return bool(inp.as_bool)
elif inp.type == "as_none":
return None
else:
raise SerializeError(f"Unhandled constant argument {inp} to deserialize")
def deserialize_sym_argument(self, sym_arg):
if isinstance(sym_arg, SymIntArgument):
if sym_arg.type == "as_int":
return sym_arg.as_int
elif sym_arg.type == "as_name":
return self.serialized_name_to_node[sym_arg.as_name]
elif isinstance(sym_arg, SymFloatArgument):
if sym_arg.type == "as_float":
return sym_arg.as_float
elif sym_arg.type == "as_name":
return self.serialized_name_to_node[sym_arg.as_name]
elif isinstance(sym_arg, SymBoolArgument):
if sym_arg.type == "as_bool":
return sym_arg.as_bool
elif sym_arg.type == "as_name":
return self.serialized_name_to_node[sym_arg.as_name]
raise SerializeError(f"Unknown symbolic argument type: {sym_arg}")
def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node):
self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node):
# Check single value return
if len(serialized_node.outputs) == 0:
return
if (
len(serialized_node.outputs) == 1
and "torch.ops.higher_order" in serialized_node.target
and not getattr(serialized_node, "is_hop_single_tensor_return", True)
and serialized_node.outputs[0].type != "as_none"
):
def _deserialize_hop_with_single_return(serialized_node, fx_node):
meta_val: list[Any] = []
arg = None
if serialized_node.outputs[0].type == "as_tensor":
arg = serialized_node.outputs[0].as_tensor
elif isinstance(
serialized_node.outputs[0].value,
(SymIntArgument, SymBoolArgument, SymFloatArgument),
):
arg = serialized_node.outputs[0].value
deserialized_metadata = self.deserialize_metadata(
serialized_node.metadata
)
assert arg is not None
# pyrefly: ignore # bad-argument-type
self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
fx_node.meta["val"] = tuple(meta_val)
self.serialized_name_to_node[fx_node.name] = fx_node
return
return _deserialize_hop_with_single_return(serialized_node, fx_node)
if (
len(serialized_node.outputs) == 1
and serialized_node.outputs[0].type == "as_tensor"
):
self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
return
elif len(serialized_node.outputs) == 1 and isinstance(
serialized_node.outputs[0].value,
(SymIntArgument, SymBoolArgument, SymFloatArgument),
):
self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
return
elif (
len(serialized_node.outputs) == 1
and serialized_node.outputs[0].type == "as_none"
):
# manually rename the node to a unused name to avoid naming conflicts
fx_node.meta["val"] = None
fx_node._rename(f"{self.graph._target_to_str(fx_node.target)}_unused")
return
self.deserialize_multiple_outputs(serialized_node, fx_node)
def generate_getitem(
self,
meta_val,
fx_node: torch.fx.Node,
arg: Union[TensorArgument, SymIntArgument, SymFloatArgument],
idx: int,
deserialized_metadata: dict[str, Any],
):
if isinstance(arg, TensorArgument):
name = arg.name
elif isinstance(arg, SymIntArgument):
name = arg.as_name
elif isinstance(arg, SymFloatArgument):
name = arg.as_name
else:
raise AssertionError(
f"generate_getitem got unknown argument type {type(arg)}"
)
individual_output = self.graph.create_node(
"call_function",
operator.getitem,
(fx_node, idx),
name=name,
)
self.sync_fx_node(name, individual_output)
meta_val.append(self.serialized_name_to_meta[name])
# The derived `getitem` nodes should have the same stacktrace as the
# original `fx_node`
individual_output.meta.update(deserialized_metadata)
def generate_getitems(
self,
meta_val,
fx_node: torch.fx.Node,
args,
deserialized_metadata: dict[str, Any],
):
for idx, arg in enumerate(args):
if isinstance(arg, (TensorArgument, SymIntArgument, SymFloatArgument)):
self.generate_getitem(
meta_val, fx_node, arg, idx, deserialized_metadata
)
continue
assert isinstance(arg, Argument)
if arg.type in ("as_tensor", "as_sym_int", "as_sym_float"):
self.generate_getitem(
meta_val, fx_node, arg.value, idx, deserialized_metadata
)
elif arg.type in (
"as_tensors",
"as_sym_ints",
"as_sym_floats",
"as_ints",
"as_floats",
"as_strings",
"as_bools",
"as_sym_bools",
):
list_output = self.graph.create_node(
"call_function",
operator.getitem,
(fx_node, idx),
)
meta_val.append([])
self.generate_getitems(
meta_val[-1], list_output, arg.value, deserialized_metadata
)
list_output.meta.update(deserialized_metadata)
list_output.meta["val"] = meta_val[-1]
elif arg.type == "as_none":
individual_output = self.graph.create_node(
"call_function",
operator.getitem,
(fx_node, idx),
name="as_none",
)
meta_val.append(None)
individual_output.meta["val"] = None
individual_output.meta.update(deserialized_metadata)
else:
raise NotImplementedError(f"Unimplemented node output type: {arg}")
def deserialize_multiple_outputs(
self, serialized_node: Node, fx_node: torch.fx.Node
) -> None:
deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
# Convert multiple return types to FX format.
# In FX, each node only returns one value. So in order to represent
# multiple return values, we have to emit a `getitem` node for each
# return value.
# This performs the inverse mapping of the `serialize_outputs` call in
# serialization, see [NOTE: Multiple outputs]
meta_val: list[Any] = []
if len(serialized_node.outputs) == 1:
assert isinstance(serialized_node.outputs[0].value, list)
assert isinstance(serialized_node.outputs[0].value[0], TensorArgument)
self.generate_getitems(
meta_val,
fx_node,
serialized_node.outputs[0].as_tensors,
deserialized_metadata,
)
else:
self.generate_getitems(
meta_val, fx_node, serialized_node.outputs, deserialized_metadata
)
# also update the metaval for `fx_node` to be a list(meta)
fx_node.meta["val"] = tuple(meta_val)
self.serialized_name_to_node[fx_node.name] = fx_node
def deserialize_metadata(self, metadata: dict[str, str]) -> dict[str, Any]:
ret: dict[str, Any] = {}
if stack_trace := metadata.get("stack_trace"):
ret["stack_trace"] = stack_trace
def deserialize_meta_func(serialized_target: str):
module = None
if serialized_target.startswith("torch.nn"):
module = torch.nn
serialized_target_names = serialized_target.split(".")[2:]
elif serialized_target.startswith("torch"):
module = torch
serialized_target_names = serialized_target.split(".")[1:]
else:
return self.deserialize_operator(serialized_target)
target = module
for name in serialized_target_names:
if not hasattr(target, name):
return serialized_target
else:
target = getattr(target, name)
return target
if nn_module_stack_str := metadata.get("nn_module_stack"):
# Originally serialized to "key,orig_path,type_str"
def import_nn_module_stack(key, path, ty):
return key, (path, ty)
# Helper function to split string by commas, accounting for nested parentheses/brackets
def metadata_split(metadata):
out = []
start, n = 0, 0
a, b = "[(", ")]"
for end, c in enumerate(metadata):
if c in a:
n += 1
elif c in b:
n -= 1
elif c == "," and n == 0:
out.append(metadata[start:end])
start = end + 1
out.append(metadata[start:])
assert len(out) == 3
return out
nn_module_stack = dict(
import_nn_module_stack(*metadata_split(item))
for item in nn_module_stack_str.split(ST_DELIMITER)
)
ret["nn_module_stack"] = nn_module_stack
if source_fn_st_str := metadata.get("source_fn_stack"):
# Originally serializes to "fx_node_name,op_str"
source_fn_st = []
for source_fn_str in source_fn_st_str.split(ST_DELIMITER):
name, target_str = source_fn_str.split(",")
source_fn_st.append((name, deserialize_meta_func(target_str)))
ret["source_fn_stack"] = source_fn_st
if torch_fn_str := metadata.get("torch_fn"):
ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER))
if custom_str := metadata.get("custom"):
ret["custom"] = json.loads(custom_str)
return ret
def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec:
log.debug("[deserialize_argument_spec] %s", x)
if x.type == "as_tensor":
return ep.TensorArgument(name=x.as_tensor.name)
elif x.type == "as_sym_int":
return ep.SymIntArgument(name=x.as_sym_int.as_name)
elif x.type == "as_sym_float":
return ep.SymFloatArgument(name=x.as_sym_float.as_name)
elif x.type == "as_custom_obj":
return ep.ConstantArgument(
name=x.as_custom_obj.name, value=self.deserialize_input(x)
)
else:
return ep.ConstantArgument(name="", value=self.deserialize_input(x))
def deserialize_module_call_signature(
self, module_call_signature: ModuleCallSignature
) -> ep.ModuleCallSignature:
return ep.ModuleCallSignature(
inputs=[
self.deserialize_argument_spec(x) for x in module_call_signature.inputs
],
outputs=[
self.deserialize_argument_spec(x) for x in module_call_signature.outputs
],
in_spec=treespec_loads(module_call_signature.in_spec),
out_spec=treespec_loads(module_call_signature.out_spec),
forward_arg_names=names
if (names := module_call_signature.forward_arg_names)
else None,
)
def deserialize_module_call_graph(
self, module_call_graph: list[ModuleCallEntry]
) -> list[ep.ModuleCallEntry]:
log.debug("\n[deserialize_module_call_graph]")
return [
ep.ModuleCallEntry(
fqn=entry.fqn,
signature=(
self.deserialize_module_call_signature(entry.signature)
if entry.signature
else None
),
)
for entry in module_call_graph
]
@final
class ExportedProgramDeserializer(metaclass=Final):
def __init__(self, expected_opset_version: Optional[dict[str, int]] = None):
self.expected_opset_version: dict[str, int] = {}
if expected_opset_version:
self.expected_opset_version.update(expected_opset_version)
if "aten" not in self.expected_opset_version:
self.expected_opset_version["aten"] = torch._C._get_max_operator_version()
def deserialize_range_constraints(
self,
symbol_name_to_range: dict[str, symbolic_shapes.ValueRanges],
symbol_name_to_symbol: dict[str, sympy.Symbol],
) -> dict[sympy.Symbol, ValueRanges]:
log.debug("\n[deserialize_range_constraints]")
range_constraints = {}
for k, v in symbol_name_to_range.items():
if symbol := symbol_name_to_symbol.get(k):
log.debug("[deserialize_range_constraints] %s -> %s", k, v)
range_constraints[symbol] = v # type: ignore[arg-type]
else:
log.warning(
"Symbol %s did not appear in the graph that was deserialized", k
)
return range_constraints
def deserialize(
self,
exported_program: ExportedProgram,
state_dict: Union[dict[str, torch.Tensor], bytes],
constants: Union[dict[str, torch.Tensor], bytes],
example_inputs: Optional[
Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes]
] = None,
*,
_unsafe_skip_version_check=False,
) -> ep.ExportedProgram:
assert isinstance(exported_program, ExportedProgram)
version = exported_program.schema_version
# TODO(zhxchen17) blocked on thrift schema refactor
if version.major != SCHEMA_VERSION[0] and not (
version.major == 0 and version.minor == 0
):
if not _unsafe_skip_version_check:
raise SerializeError(
f"Serialized schema version {exported_program.schema_version} "
f"does not match our current schema version {SCHEMA_VERSION}."
)
symbol_name_to_range = {
k: symbolic_shapes.ValueRanges(
_int_to_sympy_int(v.min_val, -int_oo),
_int_to_sympy_int(v.max_val, int_oo),
)
for k, v in exported_program.range_constraints.items()
}
res = GraphModuleDeserializer().deserialize(
exported_program.graph_module,
state_dict,
constants,
example_inputs,
symbol_name_to_range,
)
range_constraints = self.deserialize_range_constraints(
symbol_name_to_range,
res.names_to_symbols,
)
result = ep.ExportedProgram(
root=res.graph_module,
graph=res.graph_module.graph,
graph_signature=res.signature,
state_dict=res.state_dict, # type: ignore[arg-type]
range_constraints=range_constraints,
module_call_graph=res.module_call_graph,
example_inputs=res.example_inputs,
constants=res.constants,
verifiers=[load_verifier(v) for v in exported_program.verifiers],
)
result._guards_code = exported_program.guards_code
log.debug("\n[deserialize]: %s", result)
return result
class EnumEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, bytes):
return base64.b64encode(obj).decode("utf-8")
return super().default(obj)
def _dataclass_to_dict(obj):
if isinstance(obj, _Union):
return {obj.type: _dataclass_to_dict(obj.value)}
elif dataclasses.is_dataclass(obj):
return {
f.name: _dataclass_to_dict(getattr(obj, f.name))
for f in dataclasses.fields(obj)
}
elif isinstance(obj, list):
return [_dataclass_to_dict(x) for x in obj]
elif isinstance(obj, tuple):
return tuple(_dataclass_to_dict(x) for x in obj)
elif isinstance(obj, dict):
return {k: _dataclass_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, float):
if obj == math.inf:
return "Infinity"
elif obj == -math.inf:
return "-Infinity"
elif math.isnan(obj):
return "NaN"
else:
return obj
else:
return obj
def _to_json_bytes(obj: Any) -> bytes:
return json.dumps(_dataclass_to_dict(obj), cls=EnumEncoder, allow_nan=False).encode(
"utf-8"
)
def serialize(
exported_program: ep.ExportedProgram,
opset_version: Optional[dict[str, int]] = None,
pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
) -> SerializedArtifact:
with _enable_graph_inputs_of_type_nn_module(exported_program.example_inputs):
serialized_program = ExportedProgramSerializer(
opset_version, pickle_protocol
).serialize(exported_program)
assert isinstance(serialized_program.exported_program, ExportedProgram)
json_bytes = _to_json_bytes(serialized_program.exported_program)
artifact = SerializedArtifact(
json_bytes,
serialized_program.state_dict,
serialized_program.constants,
serialized_program.example_inputs,
)
return artifact
def _dict_to_dataclass(cls, data):
assert not isinstance(cls, str), f"Unresolved class type: '{cls}'."
if typing.get_origin(cls) == Annotated:
return _dict_to_dataclass(cls.__origin__, data)
if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls):
if data is None:
return None
ty_args = typing.get_args(cls)
assert len(ty_args) == 2
return _dict_to_dataclass(ty_args[0], data)
elif isinstance(cls, type) and issubclass(cls, _Union):
assert isinstance(data, dict)
assert len(data) == 1
_type = next(iter(data.keys()))
_value = next(iter(data.values()))
assert isinstance(_type, str)
field_type = cls.__annotations__[_type]
# pyrefly: ignore # missing-attribute
return cls.create(**{_type: _dict_to_dataclass(field_type, _value)})
elif dataclasses.is_dataclass(cls):
fields = {}
type_hints = typing.get_type_hints(cls)
# For forward compatibility consideration, we ignore all the keys
# that are not showing up in the dataclass definition.
for f in dataclasses.fields(cls):
name = f.name
if name not in data:
continue
new_field_obj = _dict_to_dataclass(type_hints[name], data[name])
fields[name] = new_field_obj
return cls(**fields) # type: ignore[operator]
elif isinstance(data, list):
if len(data) == 0:
return data
d_type = typing.get_args(cls)[0]
return [_dict_to_dataclass(d_type, d) for d in data]
elif isinstance(data, dict):
v_type = typing.get_args(cls)[1]
return {k: _dict_to_dataclass(v_type, v) for k, v in data.items()}
elif cls == float:
return float(data)
return data
def _bytes_to_dataclass(cls: Any, artifact_bytes: bytes) -> Any:
artifact_str = artifact_bytes.decode("utf-8")
artifact_dict = json.loads(artifact_str)
artifact_dataclass = _dict_to_dataclass(cls, artifact_dict)
return artifact_dataclass
def deserialize(
artifact: SerializedArtifact,
expected_opset_version: Optional[dict[str, int]] = None,
*,
_unsafe_skip_version_check=False,
) -> ep.ExportedProgram:
assert isinstance(artifact.exported_program, bytes)
serialized_exported_program = _bytes_to_dataclass(
ExportedProgram, artifact.exported_program
)
return ExportedProgramDeserializer(expected_opset_version).deserialize(
serialized_exported_program,
artifact.state_dict,
artifact.constants,
artifact.example_inputs,
_unsafe_skip_version_check=_unsafe_skip_version_check,
)
def _canonicalize_graph(
sorted_inputs, sorted_outputs, graph, constants
) -> tuple[Graph, dict[str, str]]:
def _get_argument(a: Argument):
if a.type == "as_none":
return None
elif a.type == "as_tensor":
return a.as_tensor
elif a.type == "as_tensors":
return a.as_tensors
elif a.type == "as_int":
return None
elif a.type == "as_ints":
return None
elif a.type == "as_float":
return None
elif a.type == "as_floats":
return None
elif a.type == "as_string":
return None
elif a.type == "as_strings":
return None
elif a.type == "as_complex":
return None
elif a.type == "as_sym_int":
return a.as_sym_int
elif a.type == "as_sym_ints":
return a.as_sym_ints
elif a.type == "as_sym_float":
return a.as_sym_float
elif a.type == "as_sym_floats":
return a.as_sym_floats
elif a.type == "as_scalar_type":
return None
elif a.type == "as_memory_format":
return None
elif a.type == "as_layout":
return None
elif a.type == "as_device":
return None
elif a.type == "as_bool":
return None
elif a.type == "as_bools":
return None
elif a.type == "as_sym_bool":
return a.as_sym_bool
elif a.type == "as_sym_bools":
return a.as_sym_bools
elif a.type == "as_graph":
return None
elif a.type == "as_optional_tensors":
return a.as_optional_tensors
elif a.type == "as_custom_obj":
return a.as_custom_obj
elif a.type == "as_operator":
return None
else:
raise AssertionError(f"Unknown input type to the ExportedProgram: {a}")
# Stage 1: Reorder named items.
def for_args(f, a):
assert isinstance(a, Argument)
pytree.tree_map(f, _get_argument(a))
def sort_nodes(nodes):
@dataclass
class Edges:
outs: list[int]
ins: int
graph_inputs: set[str] = set()
def_table: dict[str, int] = {}
edges: dict[int, Edges] = {}
candidates: list[tuple[str, list[tuple[str, list[int]]], int]] = []
rank: dict[str, int] = {}
ret: list[Node] = []
def get_name(a) -> Optional[str]:
if a is None:
return None
if isinstance(a, TensorArgument):
return a.name
elif isinstance(a, (SymIntArgument, SymBoolArgument, SymFloatArgument)):
if a.type == "as_name":
return a.as_name
elif a.type in ("as_int", "as_bool", "as_float"):
return None
else:
raise AssertionError(f"Unknown argument type: {a}")
elif isinstance(a, OptionalTensorArgument):
if a.type == "as_tensor":
return a.as_tensor.name
elif a.type == "as_none":
return None
else:
raise AssertionError(f"Unknown optional tensor type: {a}")
elif isinstance(a, CustomObjArgument):
return a.name
else:
raise AssertionError(f"Unknown argument type: {a}")
for i in sorted_inputs:
def add_input(a):
if s := get_name(a):
graph_inputs.add(s)
for_args(add_input, i)
for idx, node in enumerate(nodes):
def add_def(a):
if s := get_name(a):
assert s not in def_table
def_table[s] = idx
for o in node.outputs:
for_args(add_def, o)
edges[idx] = Edges([], 0)
for idx, user in enumerate(nodes):
def add_edge(a):
if s := get_name(a):
if s in constants:
return
if s not in def_table:
assert s in graph_inputs
return
src = def_table[s]
edges[src].outs.append(idx)
edges[idx].ins += 1
for i in user.inputs:
for_args(add_edge, i.arg)
def add_rank(a):
if s := get_name(a):
assert s not in rank
rank[s] = len(rank)
def get_rank(a):
s = get_name(a)
if s and s not in constants:
return rank[s]
else:
return -1
for i in sorted_inputs:
for_args(add_rank, i)
def add_candidate(idx: int):
def get_ranks(i):
ranks = []
for_args(lambda x: ranks.append(get_rank(x)), i)
return ranks
node = nodes[idx]
args_rank = [(a.name, get_ranks(a.arg)) for a in node.inputs]
heapq.heappush(candidates, (node.target, args_rank, idx))
for idx, e in edges.items():
if e.ins == 0:
add_candidate(idx)
while len(candidates) > 0:
_, _, idx = heapq.heappop(candidates)
node = nodes[idx]
for o in node.outputs:
for_args(add_rank, o)
ret.append(node)
assert idx in edges
for user in edges[idx].outs:
e = edges[user]
assert e.ins > 0
e.ins -= 1
if e.ins == 0:
add_candidate(user)
edges[idx].outs.clear()
return ret
sorted_nodes = sort_nodes(graph.nodes)
assert len(sorted_nodes) == len(graph.nodes)
# Stage 2: Rename nodes.
name_table: dict[str, str] = {}
def rename_def(a):
def _rename(arg_name, values):
new_name = f"_{len(name_table)}"
assert arg_name not in name_table
name_table[arg_name] = new_name
assert arg_name in values
values[new_name] = values.pop(arg_name)
return new_name
if a is None:
return
if isinstance(a, TensorArgument):
a.name = _rename(a.name, graph.tensor_values)
elif isinstance(a, SymIntArgument):
if a.type == "as_name":
a.as_name = _rename(a.as_name, graph.sym_int_values)
elif isinstance(a, SymFloatArgument):
if a.type == "as_name":
a.as_name = _rename(a.as_name, graph.sym_float_values)
elif isinstance(a, SymBoolArgument):
if a.type == "as_name":
a.as_name = _rename(a.as_name, graph.sym_bool_values)
elif isinstance(a, CustomObjArgument):
a.name = _rename(a.name, graph.custom_obj_values)
else:
raise AssertionError(f"Unknown argument type: {a}")
def replace_use(a):
if a is None:
return
if isinstance(a, TensorArgument):
a.name = name_table.get(a.name, a.name)
elif isinstance(a, (SymIntArgument, SymFloatArgument)):
if a.type == "as_name":
a.as_name = name_table.get(a.as_name, a.as_name)
elif isinstance(a, SymBoolArgument):
if a.type == "as_name":
a.as_name = name_table.get(a.as_name, a.as_name)
elif isinstance(a, OptionalTensorArgument):
if a.type == "as_tensor":
a.as_tensor.name = name_table.get(a.as_tensor.name, a.as_tensor.name)
elif isinstance(a, CustomObjArgument):
a.name = name_table.get(a.name, a.name)
else:
raise AssertionError(f"Unknown argument type: {a}")
for i in sorted_inputs:
for_args(rename_def, i)
for n in sorted_nodes:
for o in n.outputs:
for_args(rename_def, o)
for n in sorted_nodes:
for i in n.inputs:
for_args(replace_use, i.arg)
for o in sorted_outputs:
for_args(replace_use, o)
# Stage 3: Remove unstable fields.
for n in sorted_nodes:
n.metadata.clear()
# Stage 4: Aggregate values.
# pyrefly: ignore # no-matching-overload
sorted_tensor_values = dict(
sorted(graph.tensor_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
sorted_sym_int_values = dict(
sorted(graph.sym_int_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
sorted_sym_float_values = dict(
sorted(graph.sym_float_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
sorted_sym_bool_values = dict(
sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
sorted_custom_obj_values = dict(
sorted(graph.custom_obj_values.items(), key=operator.itemgetter(0))
)
# Stage 5: Recurse in subgraphs.
counter = 0
for node in sorted_nodes:
for i in node.inputs:
a = i.arg
if a.type == "as_graph":
a.as_graph.graph, _ = _canonicalize_graph(
a.as_graph.graph.inputs,
a.as_graph.graph.outputs,
a.as_graph.graph,
constants,
)
a.as_graph.name = f"_g{counter}"
counter += 1
graph = Graph(
inputs=sorted_inputs,
outputs=sorted_outputs,
nodes=sorted_nodes,
tensor_values=sorted_tensor_values,
sym_int_values=sorted_sym_int_values,
sym_float_values=sorted_sym_float_values,
sym_bool_values=sorted_sym_bool_values,
is_single_tensor_return=graph.is_single_tensor_return,
custom_obj_values=sorted_custom_obj_values,
)
return graph, name_table
def canonicalize(
ep: ExportedProgram, constants: Optional[set[str]] = None
) -> ExportedProgram:
"""
Normalize a serialized ExportedProgram, so that different eager program which
shares the same semantics can get a single representation on disk.
This function canonicalizes an ExportedProgram by:
1. Sorting nodes in topological order.
2. Rename nodes to have unique names.
3. Remove unstable fields.
4. Aggregate the above program fields.
5. Recurse in subgraphs.
Args:
ep (ExportedProgram): The ExportedProgram to canonicalize.
constants (Optional[set[str]]): Set of constants names
Returns:
ExportedProgram: The canonicalized exported program.
"""
ep = copy.deepcopy(ep)
# pyrefly: ignore # annotation-mismatch
constants: set[str] = constants or set()
opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0)))
range_constraints = dict(
sorted(ep.range_constraints.items(), key=operator.itemgetter(0))
)
guards_code = sorted(ep.guards_code)
module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn)
signature = ep.graph_module.signature
graph = ep.graph_module.graph
assert len(graph.inputs) == len(signature.input_specs)
assert len(graph.outputs) == len(signature.output_specs)
def rank_input(inp) -> tuple[int, Optional[str], int]:
idx, (_arg, spec) = inp
assert isinstance(spec, InputSpec)
if spec.type == "user_input":
return 5, None, idx
elif spec.type == "parameter":
return 1, spec.parameter.parameter_name, idx
elif spec.type == "buffer":
return 2, spec.buffer.buffer_name, idx
elif spec.type == "tensor_constant":
return 3, spec.tensor_constant.tensor_constant_name, idx
elif spec.type == "custom_obj":
return 4, spec.custom_obj.custom_obj_name, idx
elif spec.type == "token":
return 0, None, idx
elif spec.type == "constant_input":
return 6, spec.constant_input.name, idx
else:
raise AssertionError(f"Unknown input type: {spec}")
def rank_output(out) -> tuple[int, Optional[str], int]:
idx, (_arg, spec) = out
assert isinstance(spec, OutputSpec)
if spec.type == "user_output":
return 4, None, idx
elif spec.type == "loss_output":
return 4, None, idx
elif spec.type == "parameter_mutation":
return 1, spec.parameter_mutation.parameter_name, idx
elif spec.type == "buffer_mutation":
return 2, spec.buffer_mutation.buffer_name, idx
elif spec.type == "gradient_to_parameter":
return 5, spec.gradient_to_parameter.parameter_name, idx
elif spec.type == "gradient_to_user_input":
return 6, None, idx
elif spec.type == "user_input_mutation":
return 3, None, idx
elif spec.type == "token":
return 0, None, idx
else:
raise AssertionError(f"Unknown output type: {spec}")
sorted_ins = sorted(
enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input
)
if len(sorted_ins) > 0:
sorted_inputs, input_specs = zip(*(i for idx, i in sorted_ins)) # type: ignore[assignment]
else:
sorted_inputs = ()
input_specs = ()
sorted_outs = sorted(
enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output
)
sorted_outputs, output_specs = zip(*(i for idx, i in sorted_outs)) # type: ignore[assignment]
sorted_graph, replace_table = _canonicalize_graph(
sorted_inputs, sorted_outputs, graph, constants
)
def replace_input(spec):
assert isinstance(spec, InputSpec)
if spec.type == "user_input":
arg = spec.user_input.arg
if arg.type == "as_tensor":
t = arg.as_tensor
t.name = replace_table[t.name]
elif arg.type == "as_sym_int":
s = arg.as_sym_int
if s.type == "as_name":
s.as_name = replace_table[s.as_name]
elif s.type == "as_int":
pass
else:
raise AssertionError(f"Unknown sym_int type: {s}")
elif arg.type == "as_sym_float":
f = arg.as_sym_float
if f.type == "as_name":
f.as_name = replace_table[f.as_name]
elif f.type == "as_float":
pass
else:
raise AssertionError(f"Unknown sym_float type: {f}")
elif arg.type in (
"as_none",
"as_bool",
"as_int",
"as_float",
"as_string",
"as_custom_obj",
):
return
else:
raise AssertionError(f"Unknown input type: {arg}")
elif spec.type == "parameter":
t = spec.parameter.arg
t.name = replace_table[t.name]
elif spec.type == "buffer":
t = spec.buffer.arg
t.name = replace_table[t.name]
elif spec.type == "tensor_constant":
t = spec.tensor_constant.arg
t.name = replace_table[t.name]
elif spec.type == "custom_obj":
t_custom_obj = spec.custom_obj.arg
t_custom_obj.name = replace_table[t_custom_obj.name]
return
elif spec.type == "token":
tok = spec.token.arg
tok.name = replace_table[tok.name]
elif spec.type == "constant_input":
return
else:
raise AssertionError(f"Unknown input type: {spec}")
def replace_output(out):
assert isinstance(spec, OutputSpec)
if spec.type == "user_output":
arg = spec.user_output.arg
if arg.type == "as_tensor":
t = arg.as_tensor
t.name = replace_table[t.name]
elif arg.type == "as_sym_int":
s = arg.as_sym_int
if s.type == "as_name":
s.as_name = replace_table[s.as_name]
elif s.type == "as_int":
pass
else:
raise AssertionError(f"Unknown sym_int type: {s}")
elif arg.type == "as_sym_float":
f = arg.as_sym_float
if f.type == "as_name":
f.as_name = replace_table[f.as_name]
elif f.type == "as_float":
pass
else:
raise AssertionError(f"Unknown sym_float type: {f}")
elif arg.type in ("as_none", "as_bool", "as_int", "as_float", "as_string"):
return
else:
raise AssertionError(f"Unknown input type: {arg}")
elif spec.type == "loss_output":
t = spec.loss_output.arg
t.name = replace_table[t.name]
elif spec.type == "buffer_mutation":
t = spec.buffer_mutation.arg
t.name = replace_table[t.name]
elif spec.type == "parameter_mutation":
t = spec.parameter_mutation.arg
t.name = replace_table[t.name]
elif spec.type == "gradient_to_parameter":
t = spec.gradient_to_parameter.arg
t.name = replace_table[t.name]
elif spec.type == "gradient_to_user_input":
g = spec.gradient_to_user_input
g.arg.name = replace_table[g.arg.name]
g.user_input_name = replace_table[g.user_input_name]
elif spec.type == "user_input_mutation":
u = spec.user_input_mutation
u.arg.name = replace_table[u.arg.name]
u.user_input_name = replace_table[u.user_input_name]
elif spec.type == "token":
tok = spec.token.arg
tok.name = replace_table[tok.name]
else:
raise AssertionError(f"Unknown output type: {spec}")
for spec in input_specs:
replace_input(spec)
for spec in output_specs:
replace_output(spec)
return ExportedProgram(
graph_module=GraphModule(
graph=sorted_graph,
signature=GraphSignature(
input_specs=list(input_specs),
output_specs=list(output_specs),
),
module_call_graph=module_call_graph,
),
opset_version=opset_version,
range_constraints=range_constraints,
schema_version=ep.schema_version,
verifiers=ep.verifiers,
torch_version=ep.torch_version,
guards_code=guards_code,
)
class ExtensionHandler:
"""
Base class for handling extension operators.
"""
@classmethod
def namespace(cls) -> str:
raise NotImplementedError(f"{cls.__class__} namespace() must be implemented")
@classmethod
def to_op_name(cls, op) -> str:
raise NotImplementedError(f"{cls.__class__} op_name() must be implemented")
@classmethod
def from_op_name(cls, name: str):
raise NotImplementedError(f"{cls.__class__} op_name() must be implemented")
@classmethod
def op_schema(cls, op) -> torch.FunctionSchema:
raise NotImplementedError(f"{cls.__class__} op_schema() must be implemented")
def register_extension(
op_type: type[Any],
extension_handler: type[ExtensionHandler],
):
"""Register custom de/serialization method for a node with non-standard type."""
assert issubclass(extension_handler, ExtensionHandler), (
f"Expected ExtensionHandler, got {extension_handler}."
)
assert op_type not in _serialization_registry, f"{op_type} is already registered."
assert isinstance(op_type, type) # Maybe a good idea to enforce this first.
assert not (
op_type.__module__.startswith("torch")
or op_type.__module__.startswith("builtins")
)
assert extension_handler.namespace() not in _deserialization_registry
_serialization_registry[op_type] = extension_handler
_deserialization_registry[extension_handler.namespace()] = extension_handler
def _registered_extension_types():
return tuple(_serialization_registry.keys())
# Registry to store all custom serialization implementations.
# The registry maps a operation to its serialization function (a callable), in their own
# namespace to avoid conflicts.
# Serialization: Op type --> custom handler.
# De-serialization: Namespace --> custom handler.
_serialization_registry: dict[type[Any], type[ExtensionHandler]] = {}
_deserialization_registry: dict[str, type[ExtensionHandler]] = {}