# mypy: allow-untyped-defs
import copy
import inspect
import math
import warnings
from collections.abc import Sequence
from itertools import chain
from typing import Any, Optional
import sympy
import torch
import torch.utils._pytree as pytree
from torch._export.non_strict_utils import (
_enter_enable_graph_inputs_of_type_nn_module,
_exit_enable_graph_inputs_of_type_nn_module,
_get_graph_inputs_of_type_nn_module,
)
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
_convert_range_to_int,
)
from torch._export.utils import _check_input_constraints_for_graph
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.traceback import NodeSource, NodeSourceAction
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import ValueRanges
from ._remove_effect_tokens_pass import _remove_effect_tokens
from ._tree_utils import reorder_kwargs
from .exported_program import (
ExportedProgram,
ExportGraphSignature,
InputKind,
OutputKind,
)
def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool:
"""
Refinement of TreeSpec.__eq__ where, e.g., torch.Size(...) matches tuple(...).
See _pytree_subclasses_that_lose_info in proxy_tensor.py for more details.
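    Example (a minimal illustration): dict specs match with key order ignored,
    because a dict's context (its list of keys) is compared as a set below:
    ```
    _, spec_a = pytree.tree_flatten({"x": 1, "y": 2})
    _, spec_b = pytree.tree_flatten({"y": 2, "x": 1})
    assert eq_spec(spec_a, spec_b)  # True, even though the key order differs
    ```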
"""
def _normalize_type(t):
return str(_pytree_subclasses_that_lose_info.get(t, t))
def _match_normalized_structure(a, b):
if a is b:
return True
if _normalize_type(a.type) != _normalize_type(b.type):
return False
if a.type is dict and b.type is dict:
            # for dicts, the context is a list of keys, and we allow the keys
            # to be in any order
if set(a.context) != set(b.context):
return False
elif a.context != b.context:
return False
if len(a.children_specs) != len(b.children_specs):
return False
return all(
_match_normalized_structure(a, b)
for a, b in zip(a.children_specs, b.children_specs)
)
return _match_normalized_structure(self, other)
def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list:
reordered_kwargs = reorder_kwargs(kwargs, in_spec)
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
(args, reordered_kwargs)
)
if not eq_spec(received_spec, in_spec):
raise ValueError( # noqa: B904
"Trying to flatten user inputs with exported input tree spec: \n"
f"{in_spec}\n"
"but actually got inputs with tree spec of: \n"
f"{received_spec}.\n"
"Please check that the inputs have the same number and type of "
"args and kwargs as the ones you used when tracing."
)
return flat_args_with_path
def _force_ep_signature_match(ep_guards_code: list[str], input_paths):
    # TODO (tmanlaibaatar)
    # This is a band-aid solution for the new export tracer replacing
    # shape env sources with flat_args. The real fix would be to replace
    # shape env sources with the original user sources, but that is quite
    # involved: you would need to carefully construct new sources using
    # dynamo and replace all instances of them inside the shape env. It is
    # a lot easier to manipulate the guards once they are strings, and the
    # only time we use them is when retracing or running the exported
    # program, so it is probably ok to have "not useful" guards on the ep for now.
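    # Illustrative example (hypothetical path): with
    # input_paths = [(pytree.MappingKey("z"), pytree.MappingKey("k"))],
    # the guard "L['flat_args'][0].size()[0] == 3" is rewritten to
    # "L['z']['k'].size()[0] == 3".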
name_mapping = {}
for idx, path in enumerate(input_paths):
name_mapping[f"L['flat_args'][{idx}]"] = f"L{pytree.keystr(path)}"
new_guards_code = []
for guard in ep_guards_code:
for old_name, new_name in name_mapping.items():
guard = guard.replace(old_name, new_name)
new_guards_code.append(guard)
return new_guards_code
def _force_gm_signature_match(ep_guards_code: list[str], signature):
"""
The signature of the originally exported module may not match
the signature of the unlifted graph module extracted from the
exported program. The guards code extracted from the exported
program is based on the former, but the generated guards fn is
based on the latter; thus we need to reconcile any such diff.
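    Example (illustrative): if the original signature had a var-args parameter
    `args` while the unlifted module's signature has `args_0`, then
    ```
    "L['args'][0] == 2"  ->  "L['args_0'] == 2"
    ```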
"""
import re
# Handle case where signatures may differ in var args.
orig_arg_names = set()
for g in ep_guards_code:
# match substrings of the form L['<name>'][<number>]
orig_arg_names.update(re.findall(r"L\[\'([^\']+)\'\]\[([0-9]+)\]", g))
sig_arg_names = set()
for n in signature.parameters:
# match substrings of the form <name>_<number>
sig_arg_names.update(re.findall(r"(.+)_([0-9]+)", n))
# replace L['<name>'][<number>] with L['<name>_<number>']
new_guards_code = ep_guards_code
for match in orig_arg_names:
if match in sig_arg_names:
base, idx = match
new_guards_code = [
g.replace(f"L['{base}'][{idx}]", f"L['{base}_{idx}']")
for g in new_guards_code
]
return new_guards_code
def _convert_guards_code_to_fn(
guards_code: list[str],
paths_of_placeholders: list[pytree.KeyPath],
):
"""
Generates Python code given guards code and paths of placeholders.
We assume that, based on source information,
- the tracer generates the guards code
- the input spec generates the paths of placeholders.
Example:
Suppose we are given the guards code "L['z']['k'].size()[1] == 3"
and we are given that ['z']['k'] is the path of placeholder #2.
Then we will generate:
```
torch._assert(
        args[2].size()[1] == 3,
        "Guard failed: z['k'].size()[1] == 3",
)
```
    FAQ: Why do we generate code based on the (flattened) args instead of
    the original (unflattened) inputs? Because using the latter would require
    inserting an additional pytree.unflatten call in our graph.
FAQ: Why do we not emit RuntimeError on guard failure as we used to?
Because it is inconvenient :/, get used to AssertionError instead.
"""
import ast
from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP
actual_guards_code = []
shadow_guards_code = []
for c in guards_code:
a, s = c, c
for idx, path in enumerate(paths_of_placeholders):
# e.g., replace L['z']['k'] with args[2] for Python code (actual)
a = a.replace("L" + pytree.keystr(path), f"args[{idx}]")
# e.g., replace L['z']['k'] with z['k'] for error message (shadow)
s = s.replace(
"L" + pytree.keystr(path),
path[0].key + pytree.keystr(path[1:]), # type: ignore[attr-defined]
)
actual_guards_code.append(a)
shadow_guards_code.append(s.replace("\n", ""))
# generate function code as str
code_str = "\ndef _(*args):\n"
for actual, shadow in zip(actual_guards_code, shadow_guards_code):
# printing guards code may potentially introduce redundant parens;
# we can normalize them out for readability by parsing/unparsing
# NOTE: this is not necessary for correctness, just deemed desirable
_shadow = ast.unparse(ast.parse(shadow, mode="eval"))
# actual code and shadow error message
code_str += f' torch._assert({actual}, "Guard failed: {_shadow}")\n'
code_str += " return\n"
# populate namespace with sympy globals, materialize function (named `_`)
namespace = {**SYMPY_INTERP}
exec(code_str, namespace)
# create and return a module whose forward is the materialized function
# NOTE: we want Dynamo to trace through this module, to repopulate guards:
# otherwise we would lose them when retracing
    # NOTE: calling this module is purely a side effect (its result has no
    # users), so it must be marked impure to avoid being cleaned up by DCE
guards_fn = GuardsFn()
guards_fn.forward = torch._dynamo.dont_skip_tracing(namespace["_"]) # type: ignore[call-overload, method-assign]
guards_fn._is_impure = True # type: ignore[assignment]
return guards_fn
@torch._dynamo.disable
def _check_input_constraints_for_module(self, args, kwargs):
flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
_check_input_constraints_for_graph(
self.graph.find_nodes(op="placeholder"),
flat_args_with_path,
self.range_constraints,
)
def _check_input_constraints_pre_hook(self, args, kwargs):
# preserve current behavior for clients that do not want any validation
if not self.validate_inputs:
return
    # when a guards function exists, assume that the graph calls it,
    # so we do not need to check input constraints...but we still want
    # to check that the inputs match, otherwise we'd get obscure pytree errors
if hasattr(self, "_guards_fn"):
_check_inputs_match(args, kwargs, self._in_spec)
return
    # NOTE: for some reason, Dynamo is tracing into this; we should see why and
    # put compile at the right place. Until then, we can skip the input
    # constraint checks when compiling.
if not torch.compiler.is_dynamo_compiling():
_check_input_constraints_for_module(self, args, kwargs)
def _unlift_inputs_as_getattr(
gm: torch.fx.GraphModule,
lifted_inputs: Sequence[Optional[str]],
) -> tuple[dict[str, torch.fx.Node], dict[str, torch.fx.Node]]:
"""
    Unlift inputs that refer to params/buffers/constants by converting them
    into get_attr nodes in the graph.
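    Example (illustrative): with lifted_inputs = ["weight", None],
    ```
    before: %p_weight : placeholder    %x : placeholder
    after:  %weight   : get_attr       %x : placeholder
    ```
    and the returned dicts are ({"weight": <get_attr node>}, {"x": <x node>}).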
"""
unlifted_name_to_node = {}
input_name_to_node = {}
placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
assert len(lifted_inputs) == len(placeholder_nodes)
for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs):
if lifted_node is None:
input_name_to_node[input_node.name] = input_node
else:
with gm.graph.inserting_after(input_node):
# It is fine to ignore this warning because
# it is guaranteed that we will populate this
# attr later.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
getattr_node = gm.graph.get_attr(lifted_node)
input_node.replace_all_uses_with(getattr_node)
metadata = input_node.meta
gm.graph.erase_node(input_node)
getattr_node.meta = metadata
getattr_node.meta["from_node"] = [
NodeSource(
input_node,
"ExportedProgram.module().unlift()",
[NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
)
]
unlifted_name_to_node[lifted_node] = getattr_node
return unlifted_name_to_node, input_name_to_node
def _insert_copy_for_mutations(
gm: torch.fx.GraphModule,
mutated_outputs: Sequence[Optional[str]],
unlifted_name_to_node: dict[str, torch.fx.Node],
input_name_to_node: dict[str, torch.fx.Node],
) -> None:
"""
    Find all the buffers and inputs that were mutated and insert copy_
    operators to reflect the mutations.
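    Example (illustrative): with mutated_outputs = ["buf", None] and graph
    outputs (new_buf, user_out), we insert
    ```
    copy_ = torch.ops.aten.copy_.default(%buf, %new_buf)
    ```
    and rewrite the graph to return only (user_out,).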
"""
output_node = gm.graph.output_node()
outputs = pytree.tree_flatten(output_node.args)[0]
assert len(outputs) == len(mutated_outputs)
user_output_nodes = []
return_nodes_to_copy = {}
for return_node, mutated_node_name in zip(outputs, mutated_outputs):
if mutated_node_name is None:
user_output_nodes.append(return_node)
continue
if mutated_node_name in unlifted_name_to_node:
mutated_node = unlifted_name_to_node[mutated_node_name]
elif mutated_node_name in input_name_to_node:
mutated_node = input_name_to_node[mutated_node_name]
else:
raise RuntimeError(
f"Could not find {mutated_node_name} in either buffer or input nodes"
)
with gm.graph.inserting_before(output_node):
copy_node = gm.graph.call_function(
torch.ops.aten.copy_.default, (mutated_node, return_node)
)
return_nodes_to_copy[return_node] = copy_node
output_args = tuple(
return_nodes_to_copy.get(node, node) for node in user_output_nodes
)
with gm.graph.inserting_before(output_node):
# Only return user outputs
new_output = gm.graph.output(output_args)
output_node.replace_all_uses_with(new_output)
gm.graph.erase_node(output_node)
new_output.name = output_node.name
new_output.meta.update(output_node.meta)
new_output.meta["from_node"] = [
NodeSource(
output_node,
"ExportedProgram.module().unlift()",
[NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
)
]
def _get_codegen(
in_spec: pytree.TreeSpec,
out_spec: Optional[pytree.TreeSpec],
forward_arg_names: Optional[list[str]] = None,
) -> _PyTreeCodeGen:
"""
Create the codegen for the graph module based on the in/out specs
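    Example (illustrative): for an in_spec flattened from ((x, y), {"z": z})
    with no forward_arg_names, the generated names are ["arg_0", "arg_1", "z"].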
"""
if forward_arg_names:
names = forward_arg_names
elif (
in_spec.type is tuple
and in_spec.num_children == 2
and in_spec.children_specs[0].type is tuple
and in_spec.children_specs[1].type is dict
):
# if in_spec contains the args (tuple) and kwargs (dict)
names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
# add kwarg names
names.extend(in_spec.children_specs[1].context)
else:
names = [f"arg_{i}" for i in range(in_spec.num_children)]
return _PyTreeCodeGen(
_PyTreeInfo(
names,
in_spec,
out_spec,
)
)
def _unlift(
gm: torch.fx.GraphModule,
lifted_inputs: Sequence[Optional[str]],
mutated_outputs: Sequence[Optional[str]],
in_spec: pytree.TreeSpec,
out_spec: Optional[pytree.TreeSpec],
forward_arg_names: Optional[list[str]] = None,
):
"""
Args:
        lifted_inputs: A list matching the graph module's input nodes. For
            an input node that refers to a lifted parameter/buffer, this list
            contains the fqn of the corresponding attribute. Otherwise, this
            list contains None. This is used to unlift the lifted parameters
            as get_attr nodes.
        mutated_outputs: A list matching the graph module's output nodes. For
            an output node that refers to a mutated buffer or user input, this
            list contains the name of the corresponding buffer or user input
            that needs to be mutated. Otherwise, this list contains None. This
            is used to re-insert an inplace copy_ operator to copy the mutated
            values back to the original nodes.
"""
unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr(
gm, lifted_inputs
)
_insert_copy_for_mutations(
gm, mutated_outputs, unlifted_name_to_node, input_name_to_node
)
gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
gm.graph.lint()
gm.recompile()
return gm
def _register_attrs_to_new_gm(
new_gm: torch.fx.GraphModule,
graph_signature: ExportGraphSignature,
state_dict: dict[str, Any],
constants: dict[str, Any],
) -> None:
non_persistent_buffers = set(graph_signature.non_persistent_buffers)
for name in graph_signature.buffers:
if name in non_persistent_buffers:
persistent = False
value = constants[name]
else:
persistent = True
value = state_dict[name]
_assign_attr(
value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent
)
for name in graph_signature.parameters:
value = state_dict[name]
_assign_attr(
value,
new_gm,
name,
attr_kind=_AttrKind.PARAMETER,
)
    # Technically this doesn't account for constants aliased multiple times,
    # but that is ok because we have a separate pass later in the stack that
    # populates the final gm.
for name in chain(
graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants
):
value = constants[name]
_assign_attr(
value,
new_gm,
name,
attr_kind=_AttrKind.CONSTANT,
)
class _StatefulGraphModuleFactory(type):
"""
Metaclass that ensures a private constructor for _StatefulGraphModule
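    Example (illustrative):
    ```
    _StatefulGraphModule(root, graph)          # raises TypeError
    _StatefulGraphModule._create(root, graph)  # intended construction path
    ```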
"""
def __call__(cls, *args, **kwargs):
raise TypeError(
f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
)
def _create(cls, root, graph, range_constraints=None):
return super().__call__(
root,
graph,
range_constraints=range_constraints,
)
class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
def __init__(self, root, graph, range_constraints=None):
super().__init__(root, graph)
# Need to fix up non-persistent buffers.
self.range_constraints = range_constraints or []
self.validate_inputs = True
def _create_stateful_graph_module(
plain_graph_module: torch.fx.GraphModule,
range_constraints,
ep: ExportedProgram,
) -> _StatefulGraphModule:
stateful_gm = _StatefulGraphModule._create(
plain_graph_module,
plain_graph_module.graph,
range_constraints=range_constraints,
)
module_types = _get_graph_inputs_of_type_nn_module(ep.example_inputs)
stateful_gm.register_forward_pre_hook(
lambda *args, **kwargs: _enter_enable_graph_inputs_of_type_nn_module(
module_types
)
)
stateful_gm.register_forward_pre_hook(
_check_input_constraints_pre_hook, with_kwargs=True
)
stateful_gm.register_forward_hook(
lambda *args, **kwargs: _exit_enable_graph_inputs_of_type_nn_module(
module_types
),
always_call=True,
)
    # When we have a constant with requires_grad=True, we need to detach it
    # when we unlift, since tensors that require gradients should be registered
    # as parameters. But this is problematic when two constants alias the same
    # tensor, because calling detach twice would produce two different tensors.
    # This dict keeps track of detached tensors to preserve aliasing.
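    # Illustrative example: if self.a and self.b alias one tensor t with
    # requires_grad=True, t.detach() is computed once, cached in this dict,
    # and reused for both attributes, so the aliasing survives unlifting.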
original_tensor_to_detached_tensor = {}
# Fix up lifted tensor constants.
# fx.GraphModule() constructor silently turns a constant attribute of plain_graph_module
# into a buffer in stateful_gm and creates an inconsistency with graph_signature.
    # We fix this by de-registering these buffers in lifted_tensor_constants
    # and calling _assign_attr(attr_kind=CONSTANT) to register them as constants.
for constant_fqn in ep.graph_signature.lifted_tensor_constants:
        # Sometimes the constant can require gradient; this is probably a bug in
        # user code, e.g. `self.const = torch.randn(2, 2, requires_grad=True)`.
        # We call detach on the constant since it is a tensor constant and we
        # don't need to compute its gradient anyway.
        # Users should register it as a parameter if they want it to require gradient.
buffer = stateful_gm.get_buffer(constant_fqn)
if buffer.requires_grad:
warnings.warn(
f"A model attribute `{constant_fqn}` requires gradient. "
f"but it's not properly registered as a parameter. "
f"torch.export will detach it and treat it as a constant tensor "
f"but please register it as parameter instead."
)
detached_buffer = buffer.detach()
original_tensor_to_detached_tensor[buffer] = detached_buffer
buffer = detached_buffer
*prefix, field = constant_fqn.rsplit(".")
submod = torch.fx.graph_module._get_attr_via_attr_list(stateful_gm, prefix)
delattr(submod, field)
_assign_attr(buffer, stateful_gm, constant_fqn, attr_kind=_AttrKind.CONSTANT)
    # Unlike params/buffers, constants are not preserved well when we create a
    # new GraphModule.
for const_name, value in ep.constants.items():
if not torch.fx.graph_module._has_attr(stateful_gm, const_name):
if isinstance(value, torch.Tensor):
if value.requires_grad:
warnings.warn(
f"A model attribute `{const_name}` requires gradient "
f"but it's not properly registered as a parameter. "
f"torch.export will detach it and treat it as a constant tensor "
f"but please register it as parameter instead."
)
if value in original_tensor_to_detached_tensor:
value = original_tensor_to_detached_tensor[value]
else:
detached_value = value.detach()
original_tensor_to_detached_tensor[value] = detached_value
value = detached_value
_assign_attr(
value,
stateful_gm,
const_name,
attr_kind=_AttrKind.CONSTANT,
)
# Fix up non-persistent buffers. torch.fx does not distinguish between
# persistent and non-persistent buffers, so we must restore that distinction
# here.
for buffer in ep.graph_signature.non_persistent_buffers:
_assign_attr(
plain_graph_module.get_buffer(buffer),
stateful_gm,
buffer,
attr_kind=_AttrKind.BUFFER,
persistent=False,
)
return stateful_gm
def _get_input_paths(example_inputs, signature):
"""
Generate paths of placeholders, needed for generating the guards function.
NOTE: Here we make use of the example inputs used for export as well as
the signature of the unlifted graph module (not preserved by export).
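    Example (illustrative): for forward(x, *, y) called with example inputs
    (args=(t1,), kwargs={"y": t2}), this returns paths roughly like
    [(MappingKey('x'),), (MappingKey('y'),)], i.e., L['x'] and L['y'].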
"""
args, kwargs = example_inputs
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    ctx = bound.arguments
flat_example_inputs_with_paths = pytree.tree_leaves_with_path(ctx)
return [path for path, _ in flat_example_inputs_with_paths]
def _replace_sources(result_str: str, flat_input_paths: list[Any]):
"""
    Given user-specified input paths, maybe fix up the guard string
    to reflect the user path instead of the tracer path.
"""
name_mapping = {}
for idx, path in enumerate(flat_input_paths):
name_mapping[f"L['flat_args'][{idx}]"] = f"L{pytree.keystr(path)}"
replace = result_str
for key, val in name_mapping.items():
replace = replace.replace(key, val)
return replace
def _get_input_guards_for_graph(
placeholders: list[torch.fx.Node],
range_constraints: dict[sympy.Symbol, ValueRanges],
paths_for_placeholders: list[pytree.KeyPath],
):
"""
    Guards generated by the tracer include conditions observed in code,
    but do not include some additional checks we typically do in export.
    For example, when dynamic shapes get specialized, are specified to be
    within a range, or are specified to be in some equational relation, the
    corresponding input validation is done within a pre_hook, specifically,
`_check_input_constraints_for_graph`.
Here we generate guards corresponding to the checks that happen in
`_check_input_constraints_for_graph`, and add them to the guards already
generated by the tracer. In the future, it may be worthwhile to separate
them so that we can allow clients to turn off one but not the other.
(Looking at you, AOTI.)
NOTE: We should eventually reconcile this logic with `build_guards` that
is used by AOT Precompile.
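    Example (illustrative): for a placeholder at path ['x'] whose meta is a
    tensor of shape (s0, 3), with range_constraints mapping s0 to [4, inf],
    this emits "L['x'].size()[0] >= 4" and "L['x'].size()[1] == 3".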
"""
deferred_expressions = []
new_guards_code = []
sources: dict[sympy.Expr, str] = {}
def handle_symint(expr, src):
if len(expr.free_symbols) == 1:
            # complex equations (e.g., involving derived dims) need to be
            # handled later, since we may not have enough information while
            # we are passing through the placeholders in order
deferred_expressions.append((src, expr))
if expr in sources:
# expressions that appear in multiple sources should force
# inputs corresponding to those sources to be equal
# e.g., x.shape[0] == y.shape[1]
orig_src = sources[expr]
new_guards_code.append(f"{src} == {orig_src}")
else:
sources[expr] = src
# process value ranges as elsewhere in export
min_val, max_val = _convert_range_to_int(range_constraints[expr])
if min_val > 2:
new_guards_code.append(f"{src} >= {min_val}")
if max_val < math.inf:
new_guards_code.append(f"{src} <= {max_val}")
for placeholder, path in zip(placeholders, paths_for_placeholders):
src = "L" + pytree.keystr(path)
meta = placeholder.meta["val"]
# specializations
if isinstance(meta, int):
new_guards_code.append(f"{src} == {meta}")
if isinstance(meta, float):
if meta == math.inf:
new_guards_code.append(f"{src} == math.inf")
elif meta == -math.inf:
new_guards_code.append(f"{src} == -math.inf")
else:
new_guards_code.append(f"{src} == {meta}")
elif isinstance(meta, str):
new_guards_code.append(f"{src} == '{meta}'")
# range constraints and equalities
elif isinstance(meta, torch.SymInt) and meta.node.expr in range_constraints:
handle_symint(meta.node.expr, src)
elif isinstance(meta, torch.Tensor):
for i, dim in enumerate(meta.shape):
src = "L" + pytree.keystr(path) + f".size()[{i}]"
if isinstance(dim, int):
# specializations
new_guards_code.append(f"{src} == {dim}")
elif (
isinstance(dim, torch.SymInt) and dim.node.expr in range_constraints
):
# range constraints and equalities
handle_symint(dim.node.expr, src)
unification_map: dict[sympy.Symbol, sympy.Expr] = {}
py_printer = torch.utils._sympy.printers.PythonPrinter()
# process complex equations (e.g., involving derived dims)
for src, expr in deferred_expressions:
# we know this is the only symbol in expr (see check above)
symbol = next(iter(expr.free_symbols))
if symbol in sources:
# if s0 is already known to be directly sourced from inputs,
# e.g., z.shape[2], we do not need to do anything further
# (assume we have already processed constraints on s0 above)
continue
# otherwise s0 has some "hidden" source like 'dim'
# example: src = y.shape[1], expr = s0 + 1
if symbol in unification_map:
# suppose that we already know that s0 = x.shape[0] * 2
# so we can emit the guard: x.shape[0] * 2 + 1 = y.shape[1]
substitution = expr.subs(unification_map)
new_guards_code.append(
py_printer.doprint(sympy.Eq(substitution, sympy.Symbol(src)))
)
else:
# we do not yet know what s0 is, but given s0 + 1 = y.shape[1],
# we can solve for s0...now knowing that s0 = y.shape[1] - 1
solution = try_solve(sympy.Eq(expr, sympy.Symbol(src)), symbol)
if solution is not None:
definition = solution[1]
unification_map[symbol] = definition
return new_guards_code
def _ok_to_generate_guards_fn():
patterns = [
"executorch",
"modai",
"on_device_ai",
"torchao",
]
    # force check_guards=False for files matching `patterns`,
    # because they have too many calls to .module() and
    # cannot handle extra call_module nodes in the graph
    # TODO: fix these files to handle guard fns
frame = inspect.currentframe()
while frame is not None:
if any(path in frame.f_code.co_filename for path in patterns):
return False
frame = frame.f_back
return True
def _unlift_exported_program_lifted_states(
ep: ExportedProgram, check_guards=True
) -> torch.fx.GraphModule:
check_guards = check_guards and _ok_to_generate_guards_fn()
# TODO T206340015
if ep.verifiers[0].dialect != "TRAINING":
ep = _remove_effect_tokens(ep)
new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
_register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
forward_arg_names = (
sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
)
lifted_inputs: list[Optional[str]] = [
(
in_spec.target
if in_spec.kind
in (
InputKind.BUFFER,
InputKind.CONSTANT_TENSOR,
InputKind.PARAMETER,
InputKind.CUSTOM_OBJ,
)
else None
)
for in_spec in ep.graph_signature.input_specs
]
mutated_outputs: list[Optional[str]] = [
(
out_spec.target
if out_spec.kind
in (
OutputKind.BUFFER_MUTATION,
OutputKind.USER_INPUT_MUTATION,
OutputKind.PARAMETER_MUTATION,
)
else None
)
for out_spec in ep.graph_signature.output_specs
]
source_node_dict = {
node.name: node for node in ep.graph.nodes if node.op != "placeholder"
}
# placeholder node name might change after deepcopy
placeholder_source_node_dict = {
node.target: node for node in ep.graph.nodes if node.op == "placeholder"
}
for node in new_gm.graph.nodes:
source_node = None
if node.op == "placeholder":
source_node = placeholder_source_node_dict.get(node.target)
else:
source_node = source_node_dict.get(node.name)
node.meta["from_node"] = [
NodeSource(
source_node,
"ExportedProgram.module()",
NodeSourceAction.CREATE,
)
]
assert ep.call_spec.in_spec is not None
new_gm = _unlift(
new_gm,
lifted_inputs,
mutated_outputs,
ep.call_spec.in_spec,
ep.call_spec.out_spec,
forward_arg_names=forward_arg_names,
)
unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep)
unlift_gm.meta.update(ep.graph_module.meta)
# create a _guards_fn submodule and insert a call to it after placeholders
graph = unlift_gm.graph
placeholders = graph.find_nodes(op="placeholder")
if check_guards and placeholders and ep.example_inputs:
sig = inspect.signature(unlift_gm.forward)
input_paths = _get_input_paths(
ep.example_inputs,
sig,
)
        # TODO (tmanlaibaatar)
        # This is a band-aid solution for the new export tracer replacing
        # shape env sources with flat_args. The real fix would be to replace
        # shape env sources with the original user sources, but that is quite
        # involved: you would need to carefully construct new sources using
        # dynamo and replace all instances of them inside the shape env. It is
        # a lot easier to manipulate the guards once they are strings, and the
        # only time we use them is when retracing or running the exported
        # program, so it is probably ok to have "not useful" guards on the ep for now.
guards_code = _get_input_guards_for_graph(
placeholders, ep.range_constraints, input_paths
)
ep_guards_code = _force_ep_signature_match(ep._guards_code, input_paths)
ep_guards_code = _force_gm_signature_match(ep_guards_code, sig)
guards_code.extend(ep_guards_code)
unlift_gm._guards_fn = _convert_guards_code_to_fn(guards_code, input_paths)
root_nn_module_stack = torch.fx._utils.first_call_function_nn_module_stack(
graph
)
with graph.inserting_after(placeholders[-1]):
node = graph.call_module("_guards_fn", tuple(placeholders))
node.meta["nn_module_stack"] = root_nn_module_stack
unlift_gm.recompile()
return unlift_gm
class GuardsFn(torch.nn.Module):
"""
Module class for guard functions.
"""
def forward(self, *args):
pass
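

# Usage sketch (illustrative; `MyModule` is a hypothetical nn.Module). This
# file provides the machinery behind ExportedProgram.module(), which unlifts
# an exported program back into a runnable, stateful module:
#
#     ep = torch.export.export(MyModule(), (torch.randn(2, 3),))
#     gm = _unlift_exported_program_lifted_states(ep)  # roughly ep.module()
#     out = gm(torch.randn(2, 3))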