Files
pytorch/torch/export/exported_program.py
Sherlock Huang bb3db079b1 [Export] Introduce class_fqn into CustomObjArgument (#118158)
Summary:
Class FQN is needed when unpacking a CustomObj instance.
For all other Arguments, e.g. Tensor, TensorList, SymInt, we always know their exact type. However, CustomObjArgument had an opaque type.
Adding this field also helps reveal the type of this opaque object.

Test Plan: CI

Differential Revision: D53029847

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118158
Approved by: https://github.com/zhxchen17
2024-01-25 18:44:25 +00:00

633 lines
23 KiB
Python

import copy
import dataclasses
import functools
import types
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.
import sympy
from torch.utils._sympy.value_ranges import ValueRanges
import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from .graph_signature import ( # noqa: F401
_sig_to_specs,
ArgumentSpec,
ConstantArgument,
CustomObjArgument,
ExportGraphSignature,
InputKind,
InputSpec,
OutputKind,
OutputSpec,
SymIntArgument,
TensorArgument,
)
# Public API of this module: the names picked up by a star import.
__all__ = ["ExportedProgram", "ModuleCallEntry", "ModuleCallSignature"]
# A graph pass: takes a GraphModule and returns a PassResult, or None.
PassType = Callable[[torch.fx.GraphModule], Union[PassResult, None]]
@dataclasses.dataclass
class ModuleCallSignature:
inputs: List[ArgumentSpec]
outputs: List[ArgumentSpec]
in_spec: pytree.TreeSpec
out_spec: pytree.TreeSpec
@dataclasses.dataclass
class ModuleCallEntry:
fqn: str
signature: Optional[ModuleCallSignature] = None
def _disable_prexisiting_fake_mode(fn):
    """Decorate ``fn`` so every call runs inside ``maybe_disable_fake_tensor_mode()``.

    All positional and keyword arguments are forwarded and ``fn``'s return
    value is passed through; ``functools.wraps`` copies ``fn``'s metadata
    (name, docstring, ``__wrapped__``) onto the wrapper.
    """

    def _call_without_fake_mode(*args, **kwargs):
        with maybe_disable_fake_tensor_mode():
            return fn(*args, **kwargs)

    return functools.wraps(fn)(_call_without_fake_mode)
class ExportedProgram:
"""
Package of a program from :func:`export`. It contains
an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing
tensor values of all lifted parameters and buffers, and various metadata.
You can call an ExportedProgram like the original callable traced by
:func:`export` with the same calling convention.
To perform transformations on the graph, use ``.module`` property to access
an :class:`torch.fx.GraphModule`. You can then use
`FX transformation <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
to rewrite the graph. Afterwards, you can simply use :func:`export`
again to construct a correct ExportedProgram.
"""
def __init__(
self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: torch.fx.Graph,
graph_signature: ExportGraphSignature,
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
range_constraints: "Dict[sympy.Symbol, Any]",
module_call_graph: List[ModuleCallEntry],
example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
verifier: Optional[Type[Any]] = None, # TODO Change typing hint to Verifier.
tensor_constants: Optional[
Dict[str, torch.Tensor]
] = None, # TODO: deprecate this
constants: Optional[
Dict[str, Union[torch.Tensor, torch._C.ScriptObject]]
] = None,
):
from torch._export.exported_program import _create_graph_module_for_export
# Remove codegen related things from the graph. It should just be a flat graph.
graph._codegen = torch.fx.graph.CodeGen()
self._graph_module = _create_graph_module_for_export(root, graph)
if isinstance(root, torch.fx.GraphModule):
self._graph_module.meta.update(root.meta)
self._graph_signature: ExportGraphSignature = graph_signature
self._state_dict: Dict[str, Any] = state_dict
self._range_constraints: "Dict[sympy.Symbol, ValueRanges]" = range_constraints
assert module_call_graph is not None
self._module_call_graph: List[ModuleCallEntry] = module_call_graph
self._example_inputs = example_inputs
self._constants = tensor_constants or constants or {}
assert self._constants is not None
from torch._export.verifier import Verifier
if verifier is None:
verifier = Verifier
assert issubclass(verifier, Verifier)
self._verifier = verifier
# Validate should be always the last step of the constructor.
self.verifier().check(self)
@property
@compatibility(is_backward_compatible=False)
def graph_module(self):
return self._graph_module
@property
@compatibility(is_backward_compatible=False)
def graph(self):
return self.graph_module.graph
@property
@compatibility(is_backward_compatible=False)
def graph_signature(self):
return self._graph_signature
@property
@compatibility(is_backward_compatible=False)
def state_dict(self):
return self._state_dict
@compatibility(is_backward_compatible=False)
def parameters(self) -> Iterator[torch.nn.Parameter]:
"""
Returns an iterator over original module's parameters.
"""
for _, param in self.named_parameters():
yield param
@compatibility(is_backward_compatible=False)
def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
"""
Returns an iterator over original module parameters, yielding
both the name of the parameter as well as the parameter itself.
"""
for param_name in self.graph_signature.parameters:
yield param_name, self.state_dict[param_name]
@compatibility(is_backward_compatible=False)
def buffers(self) -> Iterator[torch.Tensor]:
"""
Returns an iterator over original module buffers.
"""
for _, buf in self.named_buffers():
yield buf
@compatibility(is_backward_compatible=False)
def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
"""
Returns an iterator over original module buffers, yielding
both the name of the buffer as well as the buffer itself.
"""
for buffer_name in self.graph_signature.buffers:
yield buffer_name, self.state_dict[buffer_name]
@property
@compatibility(is_backward_compatible=False)
def range_constraints(self):
return self._range_constraints
@property
@compatibility(is_backward_compatible=False)
def module_call_graph(self):
return self._module_call_graph
@property
@compatibility(is_backward_compatible=False)
def example_inputs(self):
return self._example_inputs
@property
@compatibility(is_backward_compatible=False)
def call_spec(self):
from torch._export.exported_program import CallSpec
if len(self.module_call_graph) == 0:
return CallSpec(in_spec=None, out_spec=None)
assert self.module_call_graph[0].fqn == ""
return CallSpec(
in_spec=self.module_call_graph[0].signature.in_spec,
out_spec=self.module_call_graph[0].signature.out_spec,
)
@property
@compatibility(is_backward_compatible=False)
def verifier(self) -> Any:
return self._verifier
@property
@compatibility(is_backward_compatible=False)
def dialect(self) -> str:
return self._verifier.dialect
@property
@compatibility(is_backward_compatible=False)
def tensor_constants(self):
return self._constants
@property
@compatibility(is_backward_compatible=False)
def constants(self):
return self._constants
def __call__(self, *args: Any, **kwargs: Any) -> Any:
import torch._export.error as error
if self.call_spec.in_spec is not None:
try:
user_args = (args, kwargs or {})
args = fx_pytree.tree_flatten_spec(
user_args, self.call_spec.in_spec, exact_structural_match=True
) # type: ignore[assignment]
except Exception:
_, received_spec = pytree.tree_flatten(user_args)
raise TypeError( # noqa: TRY200
"Trying to flatten user inputs with exported input tree spec: \n"
f"{self.call_spec.in_spec}\n"
"but actually got inputs with tree spec of: \n"
f"{received_spec}"
)
additional_inputs = []
for input_ in self.graph_signature.input_specs:
if input_.kind == InputKind.USER_INPUT:
continue
elif input_.kind in (InputKind.PARAMETER, InputKind.BUFFER):
additional_inputs.append(self.state_dict[input_.target])
elif input_.kind in (InputKind.CONSTANT_TENSOR, InputKind.CUSTOM_OBJ):
additional_inputs.append(self.constants[input_.target])
additional_inputs = tuple(additional_inputs)
self._check_input_constraints(*args)
# NOTE: calling convention is first params, then buffers, then args as user supplied them.
# See: torch/_functorch/aot_autograd.py#L1034
res = torch.fx.Interpreter(self.graph_module).run(
*additional_inputs,
*args,
enable_io_processing=False,
)
if self.call_spec.out_spec is not None:
buffer_mutation = self.graph_signature.buffers_to_mutate
user_input_mutation = self.graph_signature.user_inputs_to_mutate
num_mutated = len(buffer_mutation) + len(user_input_mutation)
mutated_values = res[:num_mutated]
# Exclude dependency token from final result.
assertion_dep_token = self.graph_signature.assertion_dep_token
if assertion_dep_token is not None:
assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
res = res[:assertion_dep_token_index]
res = res[num_mutated:]
try:
res = pytree.tree_unflatten(res, self.call_spec.out_spec)
except Exception:
_, received_spec = pytree.tree_flatten(res)
raise error.InternalError( # noqa: TRY200
"Trying to flatten user outputs with exported output tree spec: \n"
f"{self.call_spec.out_spec}\n"
"but actually got outputs with tree spec of: \n"
f"{received_spec}"
)
finally:
user_inputs = [
spec
for spec in self.graph_signature.input_specs
if spec.kind == InputKind.USER_INPUT
]
for i, value in enumerate(mutated_values):
output_spec = self.graph_signature.output_specs[i]
if output_spec.kind == OutputKind.BUFFER_MUTATION:
assert output_spec.target is not None
self.state_dict[output_spec.target] = value
elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
assert output_spec.target is not None
index = next(
i
for i, spec in enumerate(user_inputs)
if spec.arg.name == output_spec.target
)
args[index].copy_(value)
else:
raise AssertionError(f"Unexpected kind: {output_spec.kind}")
return res
def __str__(self) -> str:
graph_module = self.graph_module.print_readable(print_output=False).replace(
"\n", "\n "
)
string = (
"ExportedProgram:\n"
f" {graph_module}\n"
f"Graph signature: {self.graph_signature}\n"
f"Range constraints: {self.range_constraints}\n"
)
return string
def module(self) -> torch.nn.Module:
"""
Returns a self contained GraphModule with all the parameters/buffers inlined.
"""
from ._unlift import _unlift_exported_program_lifted_states
module = _unlift_exported_program_lifted_states(self)
def _train(self, mode: bool = True):
raise NotImplementedError("Calling train() is not supported yet.")
def _eval(self, mode: bool = True):
raise NotImplementedError("Calling eval() is not supported yet.")
module.train = types.MethodType(_train, module) # type: ignore[method-assign]
module.eval = types.MethodType(_eval, module) # type: ignore[method-assign]
return module
@_disable_prexisiting_fake_mode
def run_decompositions(
self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
) -> "ExportedProgram":
"""
Run a set of decompositions on the exported program and returns a new
exported program. By default we will run the Core ATen decompositions to
get operators in the
`Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.
For now, we do not decompose joint graphs.
"""
from torch._decomp import core_aten_decompositions
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
_AddRuntimeAssertionsForInlineConstraintsPass,
)
from torch._export.passes.lift_constants_pass import lift_constants_pass
from torch._export.passes.replace_sym_size_ops_pass import (
_replace_sym_size_ops_pass,
)
from torch._functorch.aot_autograd import aot_export_module
def _get_placeholders(gm):
placeholders = []
for node in gm.graph.nodes:
if node.op != "placeholder":
break
placeholders.append(node)
return placeholders
decomp_table = decomp_table or core_aten_decompositions()
old_placeholders = _get_placeholders(self.graph_module)
fake_args = [node.meta["val"] for node in old_placeholders]
buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()]
for name in buffers_to_remove:
delattr(self.graph_module, name)
# TODO(zhxhchen17) Return the new graph_signature directly.
gm, graph_signature = aot_export_module(
self.graph_module, fake_args, decompositions=decomp_table, trace_joint=False
)
# Update the signatures with the new placeholder names in case they
# changed when calling aot_export
def update_arg(old_arg, new_ph):
if isinstance(old_arg, ConstantArgument):
return old_arg
elif isinstance(old_arg, TensorArgument):
return TensorArgument(name=new_ph.name)
elif isinstance(old_arg, SymIntArgument):
return SymIntArgument(name=new_ph.name)
raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")
new_placeholders = _get_placeholders(gm)
new_outputs = list(gm.graph.nodes)[-1].args[0]
# To match the output target with correct input for input mutations
# need to find the old to new placeholder map
old_new_placeholder_map = {
spec.arg.name: new_placeholders[i].name
for i, spec in enumerate(self.graph_signature.input_specs)
if not isinstance(spec.arg, ConstantArgument)
}
input_specs = [
InputSpec(spec.kind, update_arg(spec.arg, new_placeholders[i]), spec.target)
for i, spec in enumerate(self.graph_signature.input_specs)
]
output_specs = [
OutputSpec(
spec.kind,
update_arg(spec.arg, new_outputs[i]),
old_new_placeholder_map.get(spec.target, spec.target),
)
for i, spec in enumerate(self.graph_signature.output_specs)
]
assert len(new_placeholders) == len(old_placeholders)
new_graph_signature = ExportGraphSignature(
input_specs=input_specs, output_specs=output_specs
)
# NOTE: aot_export adds symint metadata for placeholders with int
# values; since these become specialized, we replace such metadata with
# the original values.
# Also, set the param/buffer metadata back to the placeholders.
for old_node, new_node in zip(old_placeholders, new_placeholders):
if not isinstance(old_node.meta["val"], torch.Tensor):
new_node.meta["val"] = old_node.meta["val"]
if (
new_node.target in new_graph_signature.inputs_to_parameters
or new_node.target in new_graph_signature.inputs_to_buffers
):
for k, v in old_node.meta.items():
new_node.meta[k] = v
# TODO unfortunately preserving graph-level metadata is not
# working well with aot_export. So we manually copy it.
# (The node-level meta is addressed above.)
gm.meta.update(self.graph_module.meta)
new_range_constraints = _get_updated_range_constraints(gm)
constants = lift_constants_pass(gm, new_graph_signature)
for k, v in constants.items():
assert k not in self.constants
self.constants[k] = v
_replace_sym_size_ops_pass(gm)
exported_program = ExportedProgram(
root=gm,
graph=gm.graph,
graph_signature=new_graph_signature,
state_dict=self.state_dict,
range_constraints=new_range_constraints,
module_call_graph=copy.deepcopy(self.module_call_graph),
example_inputs=self.example_inputs,
verifier=self.verifier,
constants=self.constants,
)
if len(new_range_constraints) > 0:
exported_program = exported_program._transform_do_not_use(
_AddRuntimeAssertionsForInlineConstraintsPass(new_range_constraints)
)
return exported_program
def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram":
pm = PassManager(list(passes))
res = pm(self.graph_module)
transformed_gm = res.graph_module if res is not None else self.graph_module
assert transformed_gm is not None
if transformed_gm is self.graph_module and not res.modified:
return self
# TODO(zhxchen17) Remove this.
def _get_updated_graph_signature(
old_signature: ExportGraphSignature,
new_gm: torch.fx.GraphModule,
) -> ExportGraphSignature:
"""
Update the graph signature's user_input/user_outputs.
"""
new_input_specs = []
for i, node in enumerate(new_gm.graph.nodes):
if node.op != "placeholder":
break
assert i < len(
old_signature.input_specs
), "Number of inputs changed after transformation"
old_input_spec = old_signature.input_specs[i]
arg = (
old_input_spec.arg
if isinstance(
old_input_spec.arg, (ConstantArgument, CustomObjArgument)
)
else type(old_input_spec.arg)(node.name)
)
new_input_specs.append(
InputSpec(old_input_spec.kind, arg, old_input_spec.target)
)
output_node = list(new_gm.graph.nodes)[-1]
assert output_node.op == "output"
new_output_specs = []
for i, node in enumerate(output_node.args[0]):
assert i < len(
old_signature.output_specs
), "Number of outputs changed after transformation"
old_output_spec = old_signature.output_specs[i]
arg = (
old_output_spec.arg
if isinstance(
old_output_spec.arg, (ConstantArgument, CustomObjArgument)
)
else type(old_output_spec.arg)(node.name)
)
new_output_specs.append(
OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
)
new_signature = ExportGraphSignature(
input_specs=new_input_specs, output_specs=new_output_specs
)
return new_signature
transformed_ep = ExportedProgram(
root=transformed_gm,
graph=transformed_gm.graph,
graph_signature=_get_updated_graph_signature(
self.graph_signature, transformed_gm
),
state_dict=self.state_dict,
range_constraints=_get_updated_range_constraints(transformed_gm),
module_call_graph=copy.deepcopy(self._module_call_graph),
example_inputs=self.example_inputs,
verifier=self.verifier,
constants=self.constants,
)
transformed_ep.graph_module.meta.update(self.graph_module.meta)
transformed_ep.graph_module.meta.update(res.graph_module.meta)
return transformed_ep
def _check_input_constraints(self, *args):
from torch._export.utils import _check_input_constraints_for_graph
placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
input_placeholders = [
p
for p, s in zip(placeholders, self.graph_signature.input_specs)
if s.kind == InputKind.USER_INPUT
]
_check_input_constraints_for_graph(
input_placeholders, args, self.range_constraints
)
def _validate(self):
self.verifier().check(self)
# TODO(zhxchen17) Formalize this.
def _update(
self, graph_module, graph_signature, state_dict=None
) -> "ExportedProgram":
return ExportedProgram(
root=graph_module,
graph=graph_module.graph,
graph_signature=graph_signature,
state_dict=state_dict or self.state_dict,
range_constraints=copy.deepcopy(self.range_constraints),
module_call_graph=copy.deepcopy(self._module_call_graph),
example_inputs=self.example_inputs,
verifier=self.verifier,
tensor_constants=self.tensor_constants,
)
def _get_updated_range_constraints(
gm: torch.fx.GraphModule,
) -> "Dict[sympy.Symbol, Any]":
def get_shape_env(gm):
vals = [
node.meta["val"]
for node in gm.graph.nodes
if node.meta.get("val", None) is not None
]
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(vals)
if fake_mode is not None:
return fake_mode.shape_env
for v in vals:
if isinstance(v, torch.SymInt):
return v.node.shape_env
shape_env = get_shape_env(gm)
if shape_env is None:
return {}
range_constraints = {
k: v
for k, v in shape_env.var_to_range.items()
if k not in shape_env.replacements
}
for k, v in shape_env.runtime_var_to_range.items():
if k not in shape_env.replacements:
range_constraints[k] = v
return range_constraints