pytorch/torch/export/exported_program.py
Pian Pawakapan 745324e487 [export] turn on hybrid symints by default (#130775)
Sets `prefer_deferred_runtime_asserts_over_guards=True` for export, so any guards emitted from `SymNode.expect_true` (for example, guards that are implicitly required to be true for an op to succeed) won't lead to constraint violations. Instead, these should appear in the graph as runtime asserts, or potentially as replacement expressions for placeholder shapes.

For example, the reshape below should emit `Eq(s0 * s1, s2)`, deferred as a runtime assert.
```
x = torch.randn(4, 8)  # [s0, s1]
y = torch.randn(32)  # [s2]
out = x.reshape(-1) + y
# this emits Eq(s0 * s1, s2), and we represent y's shape as [s0*s1] in the graph.
```
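
For reference, a hedged end-to-end sketch of the same example through `torch.export` (the module wrapper and `Dim` names are illustrative, with bounds left at defaults):
```
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x, y):
        # reshape requires s0*s1 == s2; with hybrid symints this becomes a
        # deferred runtime assert / replacement, not a constraint violation
        return x.reshape(-1) + y

ep = export(
    M(),
    (torch.randn(4, 8), torch.randn(32)),
    dynamic_shapes=({0: Dim("s0"), 1: Dim("s1")}, {0: Dim("s2")}),
)
print(ep.graph)  # y's placeholder should be annotated with shape [s0*s1]
```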

However, other complex guards can still cause export to fail: guards emitted from `SymNode.guard_bool/guard_size_oblivious` (e.g. from explicit if-else conditions in user code, or from lower-level op implementations hit during tracing) can still raise constraint violations. These can be deferred with `allow_complex_guards_as_runtime_asserts=True` (see the sketch below). We don't make this the default yet, because while it makes export more likely to succeed, it results in non-trivial asserts being emitted that often represent specialization to a variant of the op, or checks related to 0/1 specialization.
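
A hedged illustration of user code that emits such a guard (the module is made up for this note; the flag's exact plumbing into the export call is omitted):
```
import torch

class Branchy(torch.nn.Module):
    def forward(self, x):
        # With x.shape = [s0, s1], this explicit branch hits SymNode.guard_bool
        # during tracing. Normally it specializes or raises a constraint
        # violation; with allow_complex_guards_as_runtime_asserts=True the
        # condition s0*s1 > 64 would be deferred as a runtime assert instead.
        if x.shape[0] * x.shape[1] > 64:
            return x.sin()
        return x.cos()
```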

We also remove forced specializations for export and kill the `_disable_forced_specializations` flag - now any guards we can't express with Dims/DerivedDims are either handled with hybrid SymInts, or should be resolved by rewriting or deferring.

Follow up:
Currently, `ShapeEnv._set_replacement()` is called for complex equality expressions (e.g. s2 -> s0*s1 in the example above), and the ExportedProgram stores `s0*s1` in the input placeholder. This isn't checked for validity when the program is run, so an option is to avoid the replacement and/or runtime-assert on the equality.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130775
Approved by: https://github.com/avikchaudhuri
2024-07-18 17:40:58 +00:00


# mypy: allow-untyped-defs
import contextlib
import copy
import dataclasses
import functools
import re
import types
import warnings
from collections import namedtuple
from contextlib import contextmanager
from typing import (
Any,
Callable,
Dict,
final,
Iterator,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._library.fake_class_registry import FakeScriptObject
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.immutable_collections import immutable_dict, immutable_list
if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.
import sympy
from torch.utils._sympy.value_ranges import ValueRanges
import torch
import torch.utils._pytree as pytree
from torch._export.verifier import Verifier
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.fx._compatibility import compatibility
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from .graph_signature import ( # noqa: F401
_sig_to_specs,
ArgumentSpec,
ConstantArgument,
CustomObjArgument,
ExportGraphSignature,
InputKind,
InputSpec,
OutputKind,
OutputSpec,
SymIntArgument,
TensorArgument,
TokenArgument,
)
__all__ = [
"ExportedProgram",
"ModuleCallEntry",
"ModuleCallSignature",
]
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
@dataclasses.dataclass
class ModuleCallSignature:
inputs: List[ArgumentSpec]
outputs: List[ArgumentSpec]
in_spec: pytree.TreeSpec
out_spec: pytree.TreeSpec
@dataclasses.dataclass
class ModuleCallEntry:
fqn: str
signature: Optional[ModuleCallSignature] = None
def _disable_prexisiting_fake_mode(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
with maybe_disable_fake_tensor_mode():
return fn(*args, **kwargs)
return wrapper
def _fx_collection_equivalence_fn(
spec1_type: Optional[type],
spec1_context: pytree.Context,
spec2_type: Optional[type],
spec2_context: pytree.Context,
) -> bool:
"""Treat containers and their immutable variants as the same type. Otherwise
compare as normal.
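For example (illustrative)::

    _fx_collection_equivalence_fn(list, None, immutable_list, None)  # True
    _fx_collection_equivalence_fn(list, None, dict, None)            # False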
"""
if spec1_type is None or spec2_type is None:
return spec1_type is spec2_type and spec1_context == spec2_context
if issubclass(spec1_type, (dict, immutable_dict)) and issubclass(
spec2_type, (dict, immutable_dict)
):
return spec1_context == spec2_context
if issubclass(spec1_type, (list, immutable_list)) and issubclass(
spec2_type, (list, immutable_list)
):
return spec1_context == spec2_context
return spec1_type is spec2_type and spec1_context == spec2_context
def _register_cia_to_meta(*args, **kwargs):
kernel = kwargs["kernel"]
del kwargs["kernel"]
assert torch._C._dispatch_has_kernel_for_dispatch_key(
kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd
)
return kernel._op_dk(
torch._C.DispatchKey.CompositeImplicitAutograd, *args, **kwargs
)
# This list is compiled from DispatchKey.cpp.
# The idea is that we use these keys to override
# CIA decomp in export
_AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [
torch._C.DispatchKey.AutogradCPU,
torch._C.DispatchKey.AutogradCUDA,
torch._C.DispatchKey.AutogradMeta,
torch._C.DispatchKey.AutogradXLA,
torch._C.DispatchKey.AutogradLazy,
torch._C.DispatchKey.AutogradIPU,
torch._C.DispatchKey.AutogradXPU,
torch._C.DispatchKey.AutogradMPS,
torch._C.DispatchKey.AutogradHPU,
torch._C.DispatchKey.AutogradPrivateUse1,
torch._C.DispatchKey.AutogradPrivateUse2,
torch._C.DispatchKey.AutogradPrivateUse3,
]
@contextmanager
def _override_composite_implicit_decomp(ops_to_preserve, decomp_table):
# This function overrides CompositeImplicitAutograd decomp for
# functional composite ops that the user specified. Ideally we want to not decompose
# ALL composite ops, but today's C++ functionalization relies on
# the fact that it is working with the opset after decomp is run.
# Hence we can only do it for functional ops. One caveat is that
# there are some composite ops that lie about their schema (claimed to be
# functional but not really, e.g. dropout); for these cases, we just decompose.
saved_tables = {}
patched_ops = set()
removed_decomps = {}
for op_overload in ops_to_preserve:
# Our strategy for deciding if we can preserve CIA is the following:
# 1. The op should be known statically that it is functional
# 2. If it is maybe aliasing, we decompose because we must know if an op
# is mutating or aliasing.
# TODO (tmanlaibaatar) make this utility function and share it with functional_tensor
# decomp part. (https://github.com/pytorch/pytorch/issues/129431)
def assert_valid_to_preserve(op_overload):
if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
raise RuntimeError(
f"We can't detect {op_overload} as a functional op statically, so we can't preserve it"
)
if op_overload in FunctionalTensor.metadata_fns:
raise RuntimeError(
f"{op_overload} is a metadata query function, "
"it will be preserved implicitly in our tracing system. "
"Please file an issue on github if you see otherwise"
)
alias_info = len(
[i for i in op_overload._schema.arguments if i.alias_info is not None]
)
is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable
if is_mutating_or_aliasing:
raise RuntimeError(
f"{op_overload} is a mutating/aliasing op, we can't preserve it as is"
)
if not torch._C._dispatch_has_kernel(op_overload.name()):
raise RuntimeError(
f"{op_overload} is a TorchScript op, we can't preserve it as is"
)
if not torch._C._dispatch_has_kernel_for_dispatch_key(
op_overload.name(), torch._C.DispatchKey.CompositeImplicitAutograd
):
raise RuntimeError(
f"{op_overload} is not CompositeImplicitAutograd op, so we will preserve "
"it as long as there is no python decomposition"
)
return True
# If we didn't error, it means we can go ahead
assert_valid_to_preserve(op_overload)
saved_tables[op_overload] = op_overload.py_kernels.copy()
patched_ops.add(op_overload)
for override_dispatch_key in _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE:
if override_dispatch_key not in op_overload.py_kernels:
# TODO (tmanlaibaatar)https://github.com/pytorch/pytorch/issues/129430
op_overload.py_impl(override_dispatch_key)(
autograd_not_implemented(op_overload, deferred_error=True)
)
if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels:
del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd]
def _(*args, **kwargs):
return NotImplemented
op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(_)
# For fake tensor prop, we do want to register meta kernel directly
if torch._C.DispatchKey.Meta not in op_overload.py_kernels:
op_overload.py_impl(torch._C.DispatchKey.Meta)(
functools.partial(_register_cia_to_meta, kernel=op_overload)
)
if op_overload in decomp_table:
warnings.warn(
f"Deleting decomposition registered for operator `{op_overload}`, "
"which was sepecified in the `preserve_ops` list."
)
removed_decomps[op_overload] = decomp_table[op_overload]
del decomp_table[op_overload]
try:
yield
finally:
for op in patched_ops:
op.py_kernels.clear()
op.py_kernels.update(saved_tables[op])
op._dispatch_cache.clear()
for op, decomp in removed_decomps.items():
decomp_table[op] = decomp
def _rename_without_collisions(
name_map: Dict[str, str],
orig_name: str,
name: str,
is_placeholder: bool = False,
):
"""
Renames nodes to avoid name collisions, with suffixing.
name_map: map from original name to new name
orig_name: mapping key
name: candidate name (potentially suffixed, e.g. mul_2)
is_placeholder: if the node is a placeholder, avoid detecting suffix
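Example (illustrative)::

    name_map = {"a": "mul"}
    _rename_without_collisions(name_map, "b", "mul")  # returns "mul_1"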
"""
if name in name_map.values():
# non-placeholder nodes may be suffixed with a count;
# instead of adding another suffix, we try to increment it
match = re.match(r"(.*)_(\d+)", name)
if match and not is_placeholder:
name, n = match.group(1), int(match.group(2))
else:
n = 0
while (dup_name := f"{name}_{n + 1}") in name_map.values():
n += 1
name_map[orig_name] = dup_name
else:
name_map[orig_name] = name
return name_map[orig_name]
def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
"""
Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs,
and handle collisions with non-placeholders by count suffixing.
Different HOO subgraph types have different input schemas, so we first enumerate them
and gather the top-level named placeholder nodes.
"""
# gather all HOO subgraphs and their top-level named placeholder nodes
subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = []
for node in gm.graph.nodes:
if node.op == "call_function" and isinstance(
node.target, torch._ops.HigherOrderOperator
):
# HOO subgraphs have varying input schemas, so we enumerate them here
if node.target._name == "cond":
_, true_graph, false_graph, cond_args = node._args
subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args))
subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args))
elif node.target._name == "wrap_with_set_grad_enabled":
subgraph, phs = node._args[1], node._args[2:]
subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs))
elif node.target._name == "map_impl":
body_graph, array, args = node._args
subgraph_ph_tuples.append(
(getattr(gm, body_graph.target), array + args)
)
# propagate names
for subgraph, hoo_phs in subgraph_ph_tuples:
name_map: Dict[str, str] = {}
for i, node in enumerate(subgraph.graph.nodes):
if i < len(hoo_phs): # placeholder, retain name
name_map[node.name] = hoo_phs[i].name
node.name = node.target = hoo_phs[i].name
else: # non-placeholder, check for collisions
node.name = _rename_without_collisions(name_map, node.name, node.name)
# recurse and recompile
_name_hoo_subgraph_placeholders(subgraph)
subgraph.recompile()
def _decompose_and_get_gm_with_new_signature_constants(
ep,
*,
decomp_table: Dict[torch._ops.OperatorBase, Callable],
_preserve_ops: Tuple[torch._ops.OpOverload],
joint_loss_index: Optional[int],
):
from torch._export.non_strict_utils import make_fake_params_buffers
from torch._export.passes.lift_constants_pass import ConstantAttrMap
from torch._functorch.aot_autograd import aot_export_module
from torch._guards import detect_fake_mode
from torch.export._trace import (
_export_to_aten_ir,
_get_params_buffers,
_ignore_backend_decomps,
_verify_nn_module_stack,
_verify_placeholder_names,
_verify_stack_trace,
)
if ep.verifier.dialect == "TRAINING":
mod = ep.module()
fake_args = []
for node in mod.graph.nodes:
if node.op == "placeholder":
fake_args.append(node.meta["val"])
fake_args_unwrapped = pytree.tree_unflatten(fake_args, mod._in_spec)
fake_mode = detect_fake_mode(fake_args)
# Fix the graph output signature to be tuple if scalar
out_spec = mod._out_spec
orig_arg_names = mod.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined]
# aot_export expects the return type to always be a tuple.
if out_spec.type not in (list, tuple):
out_spec = pytree.TreeSpec(tuple, None, [out_spec])
mod.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo(
orig_arg_names,
mod._in_spec,
out_spec,
)
)
mod.recompile()
# the exported module will store constants & non-persistent buffers such that
# retracing treats them as persistent buffers, so we inform the constants lifting pass
# and overwrite the new graph signature using the previous program.
constant_attrs = ConstantAttrMap()
non_persistent_buffers = {
spec.target
for spec in ep.graph_signature.input_specs
if spec.kind == InputKind.BUFFER and not spec.persistent
}
for name, value in ep.constants.items():
if name in non_persistent_buffers:
continue
# recursive getattr
_mod = mod
*atoms, attr = name.split(".")
for atom in atoms:
_mod = getattr(_mod, atom)
# remove as buffer, reassign as constant/non-persistent buffer
_mod._buffers.pop(attr, None)
setattr(_mod, attr, value)
constant_attrs.add(value, name)
# get params & buffers after excluding constants
fake_params_buffers = make_fake_params_buffers(
fake_mode, _get_params_buffers(mod)
)
aten_export_artifact = _export_to_aten_ir(
mod,
# this requires empty kwargs, but not in pytree.flattened format
(
*fake_args_unwrapped[0],
*fake_args_unwrapped[1].values(),
),
{},
fake_params_buffers,
constant_attrs,
)
gm = aten_export_artifact.gm
new_graph_signature = aten_export_artifact.sig
for node in gm.graph.nodes:
# nn_module_stack
if node.op not in ["placeholder", "output"]:
for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items():
if isinstance(mod_cls, type):
node.meta["nn_module_stack"][key] = (
fqn,
mod_cls.__module__ + "." + mod_cls.__qualname__,
)
# overwrite signature for non-persistent buffers
for spec in new_graph_signature.input_specs:
if spec.kind == InputKind.BUFFER and spec.target in non_persistent_buffers:
spec.persistent = False
_verify_nn_module_stack(gm)
_verify_stack_trace(gm)
_verify_placeholder_names(gm, new_graph_signature)
return gm, new_graph_signature
old_placeholders = [
node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
]
fake_args = [node.meta["val"] for node in old_placeholders]
buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()]
for name in buffers_to_remove:
delattr(ep.graph_module, name)
from torch._guards import detect_fake_mode
# TODO(zhxhchen17) Return the new graph_signature directly.
fake_mode = detect_fake_mode(fake_args)
fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode
with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp(
_preserve_ops,
decomp_table,
):
gm, graph_signature = aot_export_module(
ep.graph_module,
fake_args,
decompositions=decomp_table,
trace_joint=joint_loss_index is not None,
output_loss_index=joint_loss_index,
)
# Update the signatures with the new placeholder names in case they
# changed when calling aot_export
def update_arg(old_arg, new_ph):
if isinstance(old_arg, ConstantArgument):
return old_arg
elif isinstance(old_arg, TensorArgument):
return TensorArgument(name=new_ph.name)
elif isinstance(old_arg, SymIntArgument):
return SymIntArgument(name=new_ph.name)
raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")
new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
new_outputs = list(gm.graph.nodes)[-1].args[0]
# rename the placeholders
assert len(new_placeholders) == len(old_placeholders)
for old_ph, new_ph in zip(old_placeholders, new_placeholders):
new_ph.name = new_ph.target = old_ph.name
# handle name collisions with newly decomposed graph nodes
name_map = {ph.name: ph.name for ph in new_placeholders}
for node in gm.graph.nodes:
if node.op == "placeholder":
continue
node.name = _rename_without_collisions(name_map, node.name, node.name)
# propagate names to higher order op subgraphs
_name_hoo_subgraph_placeholders(gm)
# Run this pass before creating input/output specs, since size-related CSE/DCE might affect output signature.
# Overwrite output specs afterwards.
from torch._export.passes._node_metadata_hook import (
_node_metadata_hook,
_set_node_metadata_hook,
)
from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names
if not torch._dynamo.config.do_not_emit_runtime_asserts:
stack_trace = (
'File "torch/fx/passes/runtime_assert.py", line 24, '
"in insert_deferred_runtime_asserts"
)
shape_env = _get_shape_env(gm)
if shape_env is not None:
with _set_node_metadata_hook(
gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
):
insert_deferred_runtime_asserts(
gm,
shape_env,
f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
export=True,
)
# update output specs
gm.recompile()
for i, name in enumerate(_graph_output_names(gm)):
if isinstance(new_outputs[i], torch.fx.Node):
new_outputs[i].name = name
# To match the output target with correct input for input mutations
# need to find the old to new placeholder map
old_new_placeholder_map = {
spec.arg.name: new_placeholders[i].name
for i, spec in enumerate(ep.graph_signature.input_specs)
if not isinstance(spec.arg, ConstantArgument)
}
input_specs = [
InputSpec(
spec.kind,
update_arg(spec.arg, new_placeholders[i]),
spec.target,
spec.persistent,
)
for i, spec in enumerate(ep.graph_signature.input_specs)
]
output_specs = [
OutputSpec(
spec.kind,
update_arg(spec.arg, new_outputs[i]),
old_new_placeholder_map.get(spec.target, spec.target),
)
for i, spec in enumerate(ep.graph_signature.output_specs)
]
if joint_loss_index is not None:
assert graph_signature.backward_signature is not None
gradients = graph_signature.backward_signature.gradients_to_user_inputs
assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs)
specs = {
graph_signature.user_inputs[i]: spec
for i, spec in enumerate(ep.graph_signature.input_specs)
if isinstance(spec.arg, TensorArgument)
}
for i, node in enumerate(new_outputs[len(output_specs) :]):
source = gradients[node.name]
spec = specs[source] # type: ignore[index]
if spec.kind == InputKind.PARAMETER:
kind = OutputKind.GRADIENT_TO_PARAMETER
target = spec.target
elif spec.kind == InputKind.USER_INPUT:
kind = OutputKind.GRADIENT_TO_USER_INPUT
target = source
else:
raise AssertionError(f"Unknown input kind: {spec.kind}")
output_specs.append(
OutputSpec(
kind,
TensorArgument(name=node.name),
target,
)
)
assert len(new_placeholders) == len(old_placeholders)
new_graph_signature = ExportGraphSignature(
input_specs=input_specs, output_specs=output_specs
)
# NOTE: aot_export adds symint metadata for placeholders with int
# values; since these become specialized, we replace such metadata with
# the original values.
# Also, set the param/buffer metadata back to the placeholders.
for old_node, new_node in zip(old_placeholders, new_placeholders):
if not isinstance(old_node.meta["val"], torch.Tensor):
new_node.meta["val"] = old_node.meta["val"]
if (
new_node.target in new_graph_signature.inputs_to_parameters
or new_node.target in new_graph_signature.inputs_to_buffers
):
for k, v in old_node.meta.items():
new_node.meta[k] = v
return gm, new_graph_signature
def _decompose_exported_program(
ep,
*,
decomp_table: Dict[torch._ops.OperatorBase, Callable],
_preserve_ops: Tuple[torch._ops.OpOverload],
joint_loss_index: Optional[int],
):
gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
ep,
decomp_table=decomp_table,
_preserve_ops=_preserve_ops,
joint_loss_index=joint_loss_index,
)
# TODO unfortunately preserving graph-level metadata is not
# working well with aot_export. So we manually copy it.
# (The node-level meta is addressed above.)
gm.meta.update(ep.graph_module.meta)
new_range_constraints = _get_updated_range_constraints(
gm,
ep.range_constraints,
_is_executorch=False,
)
exported_program = ExportedProgram(
root=gm,
graph=gm.graph,
graph_signature=new_graph_signature,
state_dict=ep.state_dict,
range_constraints=new_range_constraints,
module_call_graph=copy.deepcopy(ep.module_call_graph),
example_inputs=ep.example_inputs,
constants=ep.constants,
)
return exported_program
class ExportedProgram:
"""
Package of a program from :func:`export`. It contains
an :class:`torch.fx.Graph` that represents the Tensor computation, a state_dict containing
tensor values of all lifted parameters and buffers, and various metadata.
An ExportedProgram cannot be called directly; instead, call ``.module()`` to get a
:class:`torch.fx.GraphModule` that follows the calling convention of the original
callable traced by :func:`export`.
To perform transformations on the graph, use that :class:`torch.fx.GraphModule` with
`FX transformations <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
to rewrite the graph. Afterwards, you can simply use :func:`export`
again to construct a correct ExportedProgram.
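Example (a hedged usage sketch; ``mod`` and ``args`` are stand-ins)::

    ep = torch.export.export(mod, args)
    gm = ep.module()     # GraphModule with parameters/buffers inlined
    out = gm(*args)      # runs with the original calling convention
    # ... rewrite gm.graph with FX, then re-export:
    ep2 = torch.export.export(gm, args)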
"""
def __init__(
self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: torch.fx.Graph,
graph_signature: ExportGraphSignature,
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
range_constraints: "Dict[sympy.Symbol, Any]",
module_call_graph: List[ModuleCallEntry],
example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
verifier: Optional[Type[Any]] = None, # TODO Deprecate this.
tensor_constants: Optional[
Dict[str, torch.Tensor]
] = None, # TODO: deprecate this
constants: Optional[
Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
] = None,
*,
verifiers: Optional[List[Type[Verifier]]] = None,
):
# Remove codegen related things from the graph. It should just be a flat graph.
graph._codegen = torch.fx.graph.CodeGen()
self._graph_module = _create_graph_module_for_export(root, graph)
if isinstance(root, torch.fx.GraphModule):
self._graph_module.meta.update(root.meta)
self._graph_signature: ExportGraphSignature = graph_signature
self._state_dict: Dict[str, Any] = state_dict
self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
assert module_call_graph is not None
self._module_call_graph: List[ModuleCallEntry] = module_call_graph
self._example_inputs = example_inputs
self._constants = tensor_constants or constants or {}
assert self._constants is not None
# TODO Clean up this after we bump executorch's pin.
assert verifier is None or verifiers is None
if verifiers is None:
if verifier is None:
verifiers = [Verifier]
else:
verifiers = [verifier]
assert all(issubclass(v, Verifier) for v in verifiers)
self._verifiers = verifiers
# Validate should be always the last step of the constructor.
self._validate()
@property
@compatibility(is_backward_compatible=False)
def graph_module(self):
return self._graph_module
@property
@compatibility(is_backward_compatible=False)
def graph(self):
return self.graph_module.graph
@property
@compatibility(is_backward_compatible=False)
def graph_signature(self):
return self._graph_signature
@property
@compatibility(is_backward_compatible=False)
def state_dict(self):
return self._state_dict
@compatibility(is_backward_compatible=False)
def parameters(self) -> Iterator[torch.nn.Parameter]:
"""
Returns an iterator over original module's parameters.
"""
for _, param in self.named_parameters():
yield param
@compatibility(is_backward_compatible=False)
def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
"""
Returns an iterator over original module parameters, yielding
both the name of the parameter as well as the parameter itself.
"""
for param_name in self.graph_signature.parameters:
yield param_name, self.state_dict[param_name]
@compatibility(is_backward_compatible=False)
def buffers(self) -> Iterator[torch.Tensor]:
"""
Returns an iterator over original module buffers.
"""
for _, buf in self.named_buffers():
yield buf
@compatibility(is_backward_compatible=False)
def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
"""
Returns an iterator over original module buffers, yielding
both the name of the buffer as well as the buffer itself.
"""
non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
for buffer_name in self.graph_signature.buffers:
if buffer_name in non_persistent_buffers:
yield buffer_name, self.constants[buffer_name]
else:
yield buffer_name, self.state_dict[buffer_name]
@property
@compatibility(is_backward_compatible=False)
def range_constraints(self):
return self._range_constraints
@property
@compatibility(is_backward_compatible=False)
def module_call_graph(self):
return self._module_call_graph
@property
@compatibility(is_backward_compatible=False)
def example_inputs(self):
return self._example_inputs
@property
@compatibility(is_backward_compatible=False)
def call_spec(self):
CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])
if len(self.module_call_graph) == 0:
return CallSpec(in_spec=None, out_spec=None)
assert self.module_call_graph[0].fqn == ""
return CallSpec(
in_spec=self.module_call_graph[0].signature.in_spec,
out_spec=self.module_call_graph[0].signature.out_spec,
)
@property
@compatibility(is_backward_compatible=False)
def verifier(self) -> Any:
return self._verifiers[0]
@property
@compatibility(is_backward_compatible=False)
def dialect(self) -> str:
assert self._verifiers is not None
return self._verifiers[0].dialect
@property
@compatibility(is_backward_compatible=False)
def verifiers(self):
return self._verifiers
@property
@compatibility(is_backward_compatible=False)
def tensor_constants(self):
return self._constants
@property
@compatibility(is_backward_compatible=False)
def constants(self):
return self._constants
def _get_flat_args_with_check(self, args, kwargs):
"""Flatten args, kwargs using pytree, then, check specs.
Args:
args: List[Any] original args passed to __call__
kwargs: Dict[str, Any] original kwargs passed to __call
Returns:
A tuple of (flat_args, received_spec)
flat_args is flattend args / kwargs
received_spec is the pytree spec produced while flattening the
tuple (args, kwargs)
"""
in_spec = self.call_spec.in_spec
if in_spec is not None:
kwargs = reorder_kwargs(kwargs, in_spec)
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
(args, kwargs)
) # type: ignore[possibly-undefined]
self._check_input_constraints(flat_args_with_path)
flat_args = tuple(x[1] for x in flat_args_with_path)
return flat_args, received_spec
def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any:
"""Transform args, kwargs of __call__ to args for graph_module.
self.graph_module takes stuff from state dict as inputs.
The invariant, for an ``ep: ExportedProgram``, is:
ep(args, kwargs) ==
ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
"""
in_spec = self.call_spec.in_spec
flat_args, received_spec = self._get_flat_args_with_check(args, kwargs)
if in_spec is not None and not is_equivalent(
received_spec, in_spec, _fx_collection_equivalence_fn
):
raise ValueError(
"Trying to flatten user inputs with exported input tree spec: \n"
f"{in_spec}\n"
"but actually got inputs with tree spec of: \n"
f"{received_spec}"
)
additional_inputs = []
for input_ in self.graph_signature.input_specs:
if input_.kind == InputKind.USER_INPUT:
continue
elif input_.kind in (
InputKind.PARAMETER,
InputKind.BUFFER,
):
if input_.persistent is False:
# This is a non-persistent buffer, grab it from our
# constants instead of the state dict.
additional_inputs.append(self.constants[input_.target])
else:
additional_inputs.append(self.state_dict[input_.target])
elif input_.kind in (
InputKind.CONSTANT_TENSOR,
InputKind.CUSTOM_OBJ,
):
additional_inputs.append(self.constants[input_.target])
additional_inputs = tuple(additional_inputs)
# NOTE: calling convention is first params, then buffers, then args as user supplied them.
# See: torch/_functorch/aot_autograd.py#L1034
return additional_inputs + flat_args
def __call__(self, *args: Any, **kwargs: Any) -> Any:
raise RuntimeError(
"Unable to call ExportedProgram directly. "
"You should use `exported_program.module()` instead."
)
def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs):
"""Process potential mutations to the input.
Because self.graph_module is functional, mutations have to be written
back after execution of graph_module.
"""
import torch._export.error as error
flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs)
if self.call_spec.out_spec is not None:
buffer_mutation = self.graph_signature.buffers_to_mutate
user_input_mutation = self.graph_signature.user_inputs_to_mutate
num_mutated = len(buffer_mutation) + len(user_input_mutation)
mutated_values = res[:num_mutated]
# Exclude dependency token from final result.
assertion_dep_token = self.graph_signature.assertion_dep_token
if assertion_dep_token is not None:
assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
res = res[:assertion_dep_token_index]
res = res[num_mutated:]
try:
res = pytree.tree_unflatten(res, self.call_spec.out_spec)
except Exception:
_, received_spec = pytree.tree_flatten(res)
raise error.InternalError( # noqa: B904
"Trying to flatten user outputs with exported output tree spec: \n"
f"{self.call_spec.out_spec}\n"
"but actually got outputs with tree spec of: \n"
f"{received_spec}"
)
finally:
user_inputs = [
spec
for spec in self.graph_signature.input_specs
if spec.kind == InputKind.USER_INPUT
]
for i, value in enumerate(mutated_values):
output_spec = self.graph_signature.output_specs[i]
if output_spec.kind == OutputKind.BUFFER_MUTATION:
assert output_spec.target is not None
self.state_dict[output_spec.target] = value
elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
assert output_spec.target is not None
index = next(
i
for i, spec in enumerate(user_inputs)
if spec.arg.name == output_spec.target
)
flat_args[index].copy_(value)
else:
raise AssertionError(f"Unexpected kind: {output_spec.kind}")
return res
def __str__(self) -> str:
graph_module = self.graph_module.print_readable(
print_output=False, colored=True
).replace("\n", "\n ")
string = (
"ExportedProgram:\n"
f" {graph_module}\n"
f"Graph signature: {self.graph_signature}\n"
f"Range constraints: {self.range_constraints}\n"
)
return string
def module(self) -> torch.nn.Module:
"""
Returns a self-contained GraphModule with all the parameters/buffers inlined.
"""
from ._unlift import _unlift_exported_program_lifted_states
module = _unlift_exported_program_lifted_states(self)
def _train(self, mode: bool = True):
raise NotImplementedError("Calling train() is not supported yet.")
def _eval(self, mode: bool = True):
raise NotImplementedError("Calling eval() is not supported yet.")
module.train = types.MethodType(_train, module) # type: ignore[method-assign]
module.eval = types.MethodType(_eval, module) # type: ignore[method-assign]
return module
def _num_lifted_params_buffers(self):
return next(
(
i
for i, s in enumerate(self._graph_signature.input_specs)
if s.kind == InputKind.USER_INPUT
),
len(self._graph_signature.input_specs),
)
@_disable_prexisiting_fake_mode
def run_decompositions(
self,
decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
_preserve_ops: Tuple[torch._ops.OpOverload] = (), # type: ignore[assignment]
) -> "ExportedProgram":
"""
Run a set of decompositions on the exported program and return a new
exported program. By default we run the Core ATen decompositions to
get operators in the
`Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.
For now, we do not decompose joint graphs.
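Example (hedged; ``ep`` is assumed to be an already-exported program)::

    core_ep = ep.run_decompositions()  # uses the Core ATen decomp table
    # or pass an explicit table:
    from torch._decomp import core_aten_decompositions
    core_ep = ep.run_decompositions(core_aten_decompositions())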
"""
from torch._decomp import core_aten_decompositions
if decomp_table is None:
decomp_table = core_aten_decompositions()
return _decompose_exported_program(
self,
decomp_table=decomp_table,
_preserve_ops=_preserve_ops, # type: ignore[arg-type]
joint_loss_index=None,
)
def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram":
pm = PassManager(list(passes))
# Since we abstractly run the passes, we need to disable backend decomp here
# again.
from torch.export._trace import _ignore_backend_decomps
with _ignore_backend_decomps():
res = pm(self.graph_module)
transformed_gm = res.graph_module if res is not None else self.graph_module
assert transformed_gm is not None
if transformed_gm is self.graph_module and not res.modified:
return self
# TODO(zhxchen17) Remove this.
def _get_updated_graph_signature(
old_signature: ExportGraphSignature,
new_gm: torch.fx.GraphModule,
) -> ExportGraphSignature:
"""
Update the graph signature's user_input/user_outputs.
"""
new_input_specs = []
for i, node in enumerate(new_gm.graph.nodes):
if node.op != "placeholder":
break
assert i < len(
old_signature.input_specs
), "Number of inputs changed after transformation"
old_input_spec = old_signature.input_specs[i]
arg = (
old_input_spec.arg
if isinstance(
old_input_spec.arg, (ConstantArgument, CustomObjArgument)
)
else type(old_input_spec.arg)(node.name)
)
new_input_specs.append(
InputSpec(
old_input_spec.kind,
arg,
old_input_spec.target,
old_input_spec.persistent,
)
)
output_node = list(new_gm.graph.nodes)[-1]
assert output_node.op == "output"
new_output_specs = []
for i, node in enumerate(output_node.args[0]):
assert i < len(
old_signature.output_specs
), "Number of outputs changed after transformation"
old_output_spec = old_signature.output_specs[i]
arg = (
old_output_spec.arg
if isinstance(
old_output_spec.arg, (ConstantArgument, CustomObjArgument)
)
else type(old_output_spec.arg)(node.name)
)
new_output_specs.append(
OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
)
new_signature = ExportGraphSignature(
input_specs=new_input_specs, output_specs=new_output_specs
)
return new_signature
transformed_ep = ExportedProgram(
root=transformed_gm,
graph=transformed_gm.graph,
graph_signature=_get_updated_graph_signature(
self.graph_signature, transformed_gm
),
state_dict=self.state_dict,
range_constraints=_get_updated_range_constraints(
transformed_gm,
self.range_constraints,
_is_executorch=False,
),
module_call_graph=copy.deepcopy(self._module_call_graph),
example_inputs=self.example_inputs,
verifier=self.verifier,
constants=self.constants,
)
transformed_ep.graph_module.meta.update(self.graph_module.meta)
transformed_ep.graph_module.meta.update(res.graph_module.meta)
return transformed_ep
def _check_input_constraints(self, flat_args_with_path):
from torch._export.utils import _check_input_constraints_for_graph
placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
input_placeholders = [
p
for p, s in zip(placeholders, self.graph_signature.input_specs)
if s.kind == InputKind.USER_INPUT
]
_check_input_constraints_for_graph(
input_placeholders, flat_args_with_path, self.range_constraints
)
@final
def _validate(self):
assert (
len(self.verifiers) > 0
), "ExportedProgram must have at least one verifier."
for v in self.verifiers:
v().check(self)
# TODO(zhxchen17) Formalize this.
def _update(
self, graph_module, graph_signature, *, state_dict=None, verifiers=None
) -> "ExportedProgram":
return ExportedProgram(
root=graph_module,
graph=graph_module.graph,
graph_signature=graph_signature,
state_dict=state_dict if state_dict is not None else self.state_dict,
range_constraints=copy.deepcopy(self.range_constraints),
module_call_graph=copy.deepcopy(self._module_call_graph),
example_inputs=self.example_inputs,
verifiers=verifiers if verifiers is not None else self.verifiers,
constants=self.constants,
)
def _get_shape_env(gm):
vals = [
node.meta["val"]
for node in gm.graph.nodes
if node.meta.get("val", None) is not None
]
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(vals)
if fake_mode is not None:
return fake_mode.shape_env
for v in vals:
if isinstance(v, torch.SymInt):
return v.node.shape_env
def _get_updated_range_constraints(
gm: torch.fx.GraphModule,
old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None,
_is_executorch: bool = True,
) -> "Dict[sympy.Symbol, Any]":
# FIXME(tmanlaibaatar) Remove this whole branch once https://github.com/pytorch/pytorch/pull/123764 lands
if _is_executorch:
assert old_range_constraints is None
shape_env = _get_shape_env(gm)
if shape_env is None:
return {}
range_constraints = {
k: v
for k, v in shape_env.var_to_range.items()
if k not in shape_env.replacements
}
# Only when we have an unbacked symint, and it's used as constructor inputs,
# runtime_var_to_range will make a difference compared to var_to_range.
# e.g. [2, oo) -> [0, oo)
for k, v in shape_env.var_to_range.items():
if k not in shape_env.replacements:
range_constraints[k] = v
return range_constraints
assert old_range_constraints is not None
shape_env = _get_shape_env(gm)
if shape_env is None:
return {}
range_constraints = copy.copy(old_range_constraints)
range_constraints = {
k: v for k, v in range_constraints.items() if k not in shape_env.replacements
}
# Only when we have an unbacked symint, and it's used as constructor inputs,
# runtime_var_to_range will make a difference compared to var_to_range.
# e.g. [2, oo) -> [0, oo)
for k, v in shape_env.var_to_range.items():
if k not in shape_env.replacements and k not in range_constraints:
range_constraints[k] = v
return range_constraints
def _create_graph_module_for_export(root, graph):
try:
gm = torch.fx.GraphModule(root, graph)
except SyntaxError:
# If custom objects stored in memory are being used in the graph,
# the generated python code will result in a syntax error on the custom
# object, since it is unable to parse the in-memory object. However
# we can still run the graph eagerly through torch.fx.Interpreter,
# so we will bypass this error.
warnings.warn(
"Unable to execute the generated python source code from "
"the graph. The graph module will no longer be directly callable, "
"but you can still run the ExportedProgram, and if needed, you can "
"run the graph module eagerly using torch.fx.Interpreter."
)
gm = torch.fx.GraphModule(root, torch.fx.Graph())
gm._graph = graph
return gm