# mypy: allow-untyped-defs
import copy
import inspect
import math
import warnings
from collections.abc import Sequence
from itertools import chain
from typing import Any, Optional

import sympy

import torch
import torch.utils._pytree as pytree
from torch._export.non_strict_utils import (
    _enter_enable_graph_inputs_of_type_nn_module,
    _exit_enable_graph_inputs_of_type_nn_module,
    _get_graph_inputs_of_type_nn_module,
)
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
    _convert_range_to_int,
)
from torch._export.utils import _check_input_constraints_for_graph
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.traceback import NodeSource, NodeSourceAction
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import ValueRanges

from ._remove_effect_tokens_pass import _remove_effect_tokens
from ._tree_utils import reorder_kwargs
from .exported_program import (
    ExportedProgram,
    ExportGraphSignature,
    InputKind,
    OutputKind,
)


def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool:
    """
    Refinement of TreeSpec.__eq__ where, e.g., torch.Size(...) matches tuple(...).
    See _pytree_subclasses_that_lose_info in proxy_tensor.py for more details.
    """

    def _normalize_type(t):
        return str(_pytree_subclasses_that_lose_info.get(t, t))

    def _match_normalized_structure(a, b):
        if a is b:
            return True
        if _normalize_type(a.type) != _normalize_type(b.type):
            return False
        if a.type is dict and b.type is dict:
            # in the case of dict, the context is the list of keys, and we
            # allow the keys to be in any order
            if set(a.context) != set(b.context):
                return False
        elif a.context != b.context:
            return False
        if len(a.children_specs) != len(b.children_specs):
            return False
        return all(
            _match_normalized_structure(a, b)
            for a, b in zip(a.children_specs, b.children_specs)
        )

    return _match_normalized_structure(self, other)


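# Example (illustrative): dict specs compare equal regardless of key order,
# since the context (the list of keys) is compared as a set:
#
#     _, spec_a = pytree.tree_flatten({"x": 0, "y": 1})
#     _, spec_b = pytree.tree_flatten({"y": 1, "x": 0})
#     assert eq_spec(spec_a, spec_b)

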
def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list:
    reordered_kwargs = reorder_kwargs(kwargs, in_spec)
    flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
        (args, reordered_kwargs)
    )

    if not eq_spec(received_spec, in_spec):
        raise ValueError(  # noqa: B904
            "Trying to flatten user inputs with exported input tree spec: \n"
            f"{in_spec}\n"
            "but actually got inputs with tree spec of: \n"
            f"{received_spec}.\n"
            "Please check that the inputs have the same number and type of "
            "args and kwargs as the ones you used when tracing."
        )

    return flat_args_with_path


def _force_ep_signature_match(ep_guards_code: list[str], input_paths):
    # TODO (tmanlaibaatar)
    # This is a band-aid solution for the new export tracer replacing shape
    # env sources with flat_args. The real fix would be to replace shape env
    # sources with the original user sources, but that is quite involved: you
    # need to carefully construct new sources using dynamo and replace all
    # instances of them inside the shape env. It is a lot easier to manipulate
    # the guards after they are turned into strings, and the only time we use
    # them is when retracing or running the exported program, so it is
    # probably ok to have "not useful" guards on the ep for now.
    name_mapping = {}
    for idx, path in enumerate(input_paths):
        name_mapping[f"L['flat_args'][{idx}]"] = f"L{pytree.keystr(path)}"

    new_guards_code = []
    for guard in ep_guards_code:
        for old_name, new_name in name_mapping.items():
            guard = guard.replace(old_name, new_name)
        new_guards_code.append(guard)

    return new_guards_code


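# Example (illustrative): with input_paths = [(pytree.MappingKey("x"),)], the
# mapping rewrites "L['flat_args'][0]" to "L['x']", so a guard like
# "L['flat_args'][0].size()[0] == 3" becomes "L['x'].size()[0] == 3".

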
def _force_gm_signature_match(ep_guards_code: list[str], signature):
    """
    The signature of the originally exported module may not match
    the signature of the unlifted graph module extracted from the
    exported program. The guards code extracted from the exported
    program is based on the former, but the generated guards fn is
    based on the latter; thus we need to reconcile any such diff.
    """

    import re

    # Handle case where signatures may differ in var args.
    orig_arg_names = set()
    for g in ep_guards_code:
        # match substrings of the form L['<name>'][<number>]
        orig_arg_names.update(re.findall(r"L\[\'([^\']+)\'\]\[([0-9]+)\]", g))

    sig_arg_names = set()
    for n in signature.parameters:
        # match substrings of the form <name>_<number>
        sig_arg_names.update(re.findall(r"(.+)_([0-9]+)", n))

    # replace L['<name>'][<number>] with L['<name>_<number>']
    new_guards_code = ep_guards_code
    for match in orig_arg_names:
        if match in sig_arg_names:
            base, idx = match
            new_guards_code = [
                g.replace(f"L['{base}'][{idx}]", f"L['{base}_{idx}']")
                for g in new_guards_code
            ]

    return new_guards_code


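# Example (illustrative): if the exported guards refer to a varargs input as
# L['args'][0] while the unlifted module's signature has a parameter named
# `args_0`, the guard "L['args'][0] == 3" is rewritten to "L['args_0'] == 3".

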
def _convert_guards_code_to_fn(
    guards_code: list[str],
    paths_of_placeholders: list[pytree.KeyPath],
):
    """
    Generates Python code given guards code and paths of placeholders.
    We assume that, based on source information,
    - the tracer generates the guards code
    - the input spec generates the paths of placeholders.

    Example:

    Suppose we are given the guards code "L['z']['k'].size()[1] == 3"
    and we are given that ['z']['k'] is the path of placeholder #2.
    Then we will generate:
    ```
    torch._assert(
        args[2].size()[1] == 3,
        "Guard failed: z['k'].size()[1] == 3",
    )
    ```

    FAQ: Why do we generate code based on (flattened) args instead of
    the original (unflattened) inputs? Because the latter would require
    inserting an additional pytree.unflatten call in our graph.

    FAQ: Why do we not emit RuntimeError on guard failure as we used to?
    Because it is inconvenient :/, get used to AssertionError instead.
    """

    import ast

    from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP

    actual_guards_code = []
    shadow_guards_code = []
    for c in guards_code:
        a, s = c, c
        for idx, path in enumerate(paths_of_placeholders):
            # e.g., replace L['z']['k'] with args[2] for Python code (actual)
            a = a.replace("L" + pytree.keystr(path), f"args[{idx}]")
            # e.g., replace L['z']['k'] with z['k'] for error message (shadow)
            s = s.replace(
                "L" + pytree.keystr(path),
                path[0].key + pytree.keystr(path[1:]),  # type: ignore[attr-defined]
            )
        actual_guards_code.append(a)
        shadow_guards_code.append(s.replace("\n", ""))

    # generate function code as str
    code_str = "\ndef _(*args):\n"
    for actual, shadow in zip(actual_guards_code, shadow_guards_code):
        # printing guards code may potentially introduce redundant parens;
        # we can normalize them out for readability by parsing/unparsing
        # NOTE: this is not necessary for correctness, just deemed desirable
        _shadow = ast.unparse(ast.parse(shadow, mode="eval"))
        # actual code and shadow error message
        code_str += f'    torch._assert({actual}, "Guard failed: {_shadow}")\n'
    code_str += "    return\n"

    # populate namespace with sympy globals, materialize function (named `_`)
    namespace = {**SYMPY_INTERP}
    exec(code_str, namespace)

    # create and return a module whose forward is the materialized function
    # NOTE: we want Dynamo to trace through this module, to repopulate guards:
    # otherwise we would lose them when retracing
    # NOTE: calling this module is a side effect (it has no users), so it must
    # be marked impure to avoid being cleaned up by DCE
    guards_fn = GuardsFn()
    guards_fn.forward = torch._dynamo.dont_skip_tracing(namespace["_"])  # type: ignore[call-overload, method-assign]
    guards_fn._is_impure = True  # type: ignore[assignment]
    return guards_fn


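# Illustrative sketch (hypothetical guard and placeholder path): the returned
# module raises AssertionError when a guard fails.
#
#     fn = _convert_guards_code_to_fn(
#         ["L['x'].size()[0] == 3"], [(pytree.MappingKey("x"),)]
#     )
#     fn(torch.randn(3, 2))  # passes silently
#     fn(torch.randn(4, 2))  # AssertionError: Guard failed: x.size()[0] == 3

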
@torch._dynamo.disable
def _check_input_constraints_for_module(self, args, kwargs):
    flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
    _check_input_constraints_for_graph(
        self.graph.find_nodes(op="placeholder"),
        flat_args_with_path,
        self.range_constraints,
    )


def _check_input_constraints_pre_hook(self, args, kwargs):
    # preserve current behavior for clients that do not want any validation
    if not self.validate_inputs:
        return

    # when a guards function exists, assume that the graph calls it!
    # so we do not need to check input constraints...but we still want
    # to check that the inputs match, otherwise we'd get obscure pytree errors
    if hasattr(self, "_guards_fn"):
        _check_inputs_match(args, kwargs, self._in_spec)
        return

    # NOTE: for some reason, Dynamo is tracing into this; we should see why and
    # put compile at the right place. Until then, we can skip the input
    # constraint checks.
    if not torch.compiler.is_dynamo_compiling():
        _check_input_constraints_for_module(self, args, kwargs)


def _unlift_inputs_as_getattr(
    gm: torch.fx.GraphModule,
    lifted_inputs: Sequence[Optional[str]],
) -> tuple[dict[str, torch.fx.Node], dict[str, torch.fx.Node]]:
    """
    Unlift inputs referring to params/buffers/constants as getattr nodes in
    the graph.
    """
    unlifted_name_to_node = {}
    input_name_to_node = {}

    placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    assert len(lifted_inputs) == len(placeholder_nodes)
    for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs):
        if lifted_node is None:
            input_name_to_node[input_node.name] = input_node

        else:
            with gm.graph.inserting_after(input_node):
                # It is fine to ignore this warning because
                # it is guaranteed that we will populate this
                # attr later.
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    getattr_node = gm.graph.get_attr(lifted_node)
                input_node.replace_all_uses_with(getattr_node)
                metadata = input_node.meta
                gm.graph.erase_node(input_node)
                getattr_node.meta = metadata
                getattr_node.meta["from_node"] = [
                    NodeSource(
                        input_node,
                        "ExportedProgram.module().unlift()",
                        [NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
                    )
                ]
            unlifted_name_to_node[lifted_node] = getattr_node

    return unlifted_name_to_node, input_name_to_node


def _insert_copy_for_mutations(
    gm: torch.fx.GraphModule,
    mutated_outputs: Sequence[Optional[str]],
    unlifted_name_to_node: dict[str, torch.fx.Node],
    input_name_to_node: dict[str, torch.fx.Node],
) -> None:
    """
    Find all the buffers and inputs that were mutated and insert copy_
    operators to reflect the mutations.
    """
    output_node = gm.graph.output_node()
    outputs = pytree.tree_flatten(output_node.args)[0]
    assert len(outputs) == len(mutated_outputs)

    user_output_nodes = []
    return_nodes_to_copy = {}
    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
        if mutated_node_name is None:
            user_output_nodes.append(return_node)
            continue

        if mutated_node_name in unlifted_name_to_node:
            mutated_node = unlifted_name_to_node[mutated_node_name]
        elif mutated_node_name in input_name_to_node:
            mutated_node = input_name_to_node[mutated_node_name]
        else:
            raise RuntimeError(
                f"Could not find {mutated_node_name} in either buffer or input nodes"
            )

        with gm.graph.inserting_before(output_node):
            copy_node = gm.graph.call_function(
                torch.ops.aten.copy_.default, (mutated_node, return_node)
            )
            return_nodes_to_copy[return_node] = copy_node

    output_args = tuple(
        return_nodes_to_copy.get(node, node) for node in user_output_nodes
    )
    with gm.graph.inserting_before(output_node):
        # Only return user outputs
        new_output = gm.graph.output(output_args)
        output_node.replace_all_uses_with(new_output)
        gm.graph.erase_node(output_node)
        new_output.name = output_node.name
        new_output.meta.update(output_node.meta)
        new_output.meta["from_node"] = [
            NodeSource(
                output_node,
                "ExportedProgram.module().unlift()",
                [NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
            )
        ]


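# Example (illustrative): with graph outputs (buf_updated, y) and
# mutated_outputs = ["buf", None], a copy_(buf, buf_updated) call is inserted
# before the output node and only (y,) remains as the user output.

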
def _get_codegen(
    in_spec: pytree.TreeSpec,
    out_spec: Optional[pytree.TreeSpec],
    forward_arg_names: Optional[list[str]] = None,
) -> _PyTreeCodeGen:
    """
    Create the codegen for the graph module based on the in/out specs.
    """
    if forward_arg_names:
        names = forward_arg_names
    elif (
        in_spec.type is tuple
        and in_spec.num_children == 2
        and in_spec.children_specs[0].type is tuple
        and in_spec.children_specs[1].type is dict
    ):
        # if in_spec contains the args (tuple) and kwargs (dict)
        names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
        # add kwarg names
        names.extend(in_spec.children_specs[1].context)
    else:
        names = [f"arg_{i}" for i in range(in_spec.num_children)]

    return _PyTreeCodeGen(
        _PyTreeInfo(
            names,
            in_spec,
            out_spec,
        )
    )


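# Example (illustrative): for an in_spec flattened from ((a, b), {"scale": s})
# and no forward_arg_names, the generated argument names are
# ["arg_0", "arg_1", "scale"].

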
def _unlift(
    gm: torch.fx.GraphModule,
    lifted_inputs: Sequence[Optional[str]],
    mutated_outputs: Sequence[Optional[str]],
    in_spec: pytree.TreeSpec,
    out_spec: Optional[pytree.TreeSpec],
    forward_arg_names: Optional[list[str]] = None,
):
    """
    Args:
        lifted_inputs: A list matching the graph module's input nodes. For
            an input node that is referring to a lifted parameter/buffer, this
            list will contain the fqn of the corresponding attribute. Otherwise,
            this list will contain None. This is used to unlift the lifted
            parameters as get_attr nodes.

        mutated_outputs: A list matching the graph module's output nodes. For
            an output node that is referring to a mutated buffer or user input,
            this list will contain the name of the corresponding buffer or user
            input that needs to be mutated. Otherwise, this list will contain
            None. This is used to re-insert an inplace copy_ operator to copy
            the mutated values back to the original node.
    """
    unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr(
        gm, lifted_inputs
    )
    _insert_copy_for_mutations(
        gm, mutated_outputs, unlifted_name_to_node, input_name_to_node
    )
    gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
    gm.graph.lint()
    gm.recompile()
    return gm


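# Example (illustrative): for a graph with placeholders (p_weight, p_buf, x)
# where the first two were lifted from a parameter and a buffer, one would
# pass lifted_inputs = ["weight", "buf", None]; the first two placeholders
# become get_attr nodes and x remains a user input.

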
def _register_attrs_to_new_gm(
    new_gm: torch.fx.GraphModule,
    graph_signature: ExportGraphSignature,
    state_dict: dict[str, Any],
    constants: dict[str, Any],
) -> None:
    non_persistent_buffers = set(graph_signature.non_persistent_buffers)
    for name in graph_signature.buffers:
        if name in non_persistent_buffers:
            persistent = False
            value = constants[name]
        else:
            persistent = True
            value = state_dict[name]
        _assign_attr(
            value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent
        )
    for name in graph_signature.parameters:
        value = state_dict[name]
        _assign_attr(
            value,
            new_gm,
            name,
            attr_kind=_AttrKind.PARAMETER,
        )

    # Technically this doesn't account for constants aliased under multiple
    # names, but that is ok because a separate pass later in the stack
    # populates the final gm.
    for name in chain(
        graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants
    ):
        value = constants[name]
        _assign_attr(
            value,
            new_gm,
            name,
            attr_kind=_AttrKind.CONSTANT,
        )


class _StatefulGraphModuleFactory(type):
    """
    Metaclass that ensures a private constructor for _StatefulGraphModule
    """

    def __call__(cls, *args, **kwargs):
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
        )

    def _create(cls, root, graph, range_constraints=None):
        return super().__call__(
            root,
            graph,
            range_constraints=range_constraints,
        )


class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
    def __init__(self, root, graph, range_constraints=None):
        super().__init__(root, graph)
        # Need to fix up non-persistent buffers.
        self.range_constraints = range_constraints or []
        self.validate_inputs = True


def _create_stateful_graph_module(
    plain_graph_module: torch.fx.GraphModule,
    range_constraints,
    ep: ExportedProgram,
) -> _StatefulGraphModule:
    stateful_gm = _StatefulGraphModule._create(
        plain_graph_module,
        plain_graph_module.graph,
        range_constraints=range_constraints,
    )

    module_types = _get_graph_inputs_of_type_nn_module(ep.example_inputs)
    stateful_gm.register_forward_pre_hook(
        lambda *args, **kwargs: _enter_enable_graph_inputs_of_type_nn_module(
            module_types
        )
    )
    stateful_gm.register_forward_pre_hook(
        _check_input_constraints_pre_hook, with_kwargs=True
    )

    stateful_gm.register_forward_hook(
        lambda *args, **kwargs: _exit_enable_graph_inputs_of_type_nn_module(
            module_types
        ),
        always_call=True,
    )

    # When we have a constant with requires_grad=True, we need to detach it
    # when we unlift, since tensors that require gradients should be registered
    # via parameters. But this is problematic when two constants alias each
    # other, because calling detach turns them into different tensors. This
    # dict keeps track of that logic.
    original_tensor_to_detached_tensor = {}

    # Fix up lifted tensor constants.
    # The fx.GraphModule() constructor silently turns a constant attribute of
    # plain_graph_module into a buffer in stateful_gm, creating an inconsistency
    # with graph_signature. We fix this by de-registering these buffers in
    # lifted_tensor_constants and calling _assign_attr(attr_kind=CONSTANT) to
    # register them as constants.
    for constant_fqn in ep.graph_signature.lifted_tensor_constants:
        # Sometimes the constant can require gradient; this is probably a bug
        # in user code, e.g. `self.const = torch.randn(2, 2, requires_grad=True)`.
        # We call detach on the constant since it is a tensor constant and we
        # don't need to compute its gradient anyway.
        # Users should properly register it as a parameter if they want it to
        # require gradient.
        buffer = stateful_gm.get_buffer(constant_fqn)
        if buffer.requires_grad:
            warnings.warn(
                f"A model attribute `{constant_fqn}` requires gradient, "
                f"but it's not properly registered as a parameter. "
                f"torch.export will detach it and treat it as a constant tensor, "
                f"but please register it as a parameter instead."
            )
            detached_buffer = buffer.detach()
            original_tensor_to_detached_tensor[buffer] = detached_buffer
            buffer = detached_buffer
        *prefix, field = constant_fqn.rsplit(".")
        submod = torch.fx.graph_module._get_attr_via_attr_list(stateful_gm, prefix)
        delattr(submod, field)
        _assign_attr(buffer, stateful_gm, constant_fqn, attr_kind=_AttrKind.CONSTANT)

    # Constants are not preserved well when we create a new GraphModule, unlike
    # params/buffers
    for const_name, value in ep.constants.items():
        if not torch.fx.graph_module._has_attr(stateful_gm, const_name):
            if isinstance(value, torch.Tensor):
                if value.requires_grad:
                    warnings.warn(
                        f"A model attribute `{const_name}` requires gradient, "
                        f"but it's not properly registered as a parameter. "
                        f"torch.export will detach it and treat it as a constant tensor, "
                        f"but please register it as a parameter instead."
                    )
                    if value in original_tensor_to_detached_tensor:
                        value = original_tensor_to_detached_tensor[value]
                    else:
                        detached_value = value.detach()
                        original_tensor_to_detached_tensor[value] = detached_value
                        value = detached_value
            _assign_attr(
                value,
                stateful_gm,
                const_name,
                attr_kind=_AttrKind.CONSTANT,
            )

    # Fix up non-persistent buffers. torch.fx does not distinguish between
    # persistent and non-persistent buffers, so we must restore that distinction
    # here.
    for buffer in ep.graph_signature.non_persistent_buffers:
        _assign_attr(
            plain_graph_module.get_buffer(buffer),
            stateful_gm,
            buffer,
            attr_kind=_AttrKind.BUFFER,
            persistent=False,
        )

    return stateful_gm


def _get_input_paths(example_inputs, signature):
    """
    Generate paths of placeholders, needed for generating the guards function.

    NOTE: Here we make use of the example inputs used for export as well as
    the signature of the unlifted graph module (not preserved by export).
    """

    args, kwargs = example_inputs
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    ctx = bound.arguments
    flat_example_inputs_with_paths = pytree.tree_leaves_with_path(ctx)
    return [path for path, _ in flat_example_inputs_with_paths]


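# Example (illustrative): for `def forward(self, x, *, y)` called with
# example_inputs = ((t1,), {"y": t2}), the bound arguments are
# {"x": t1, "y": t2}, so the returned paths are
# [(MappingKey('x'),), (MappingKey('y'),)].

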
def _replace_sources(result_str: str, flat_input_paths: list[Any]):
    """
    Given user-specified input paths, maybe fix up the guard string
    to reflect the user path instead of the tracer path.
    """
    name_mapping = {}
    for idx, path in enumerate(flat_input_paths):
        name_mapping[f"L['flat_args'][{idx}]"] = f"L{pytree.keystr(path)}"

    replace = result_str
    for key, val in name_mapping.items():
        replace = replace.replace(key, val)
    return replace


def _get_input_guards_for_graph(
    placeholders: list[torch.fx.Node],
    range_constraints: dict[sympy.Symbol, ValueRanges],
    paths_for_placeholders: list[pytree.KeyPath],
):
    """
    Guards generated by the tracer include conditions observed in code, but
    do not include some additional checks we typically do in export.
    For example, when dynamic shapes get specialized, are specified to be
    within a range, or are specified to be in some equational relation,
    the corresponding input validation is done within a pre-hook, specifically,
    `_check_input_constraints_for_graph`.

    Here we generate guards corresponding to the checks that happen in
    `_check_input_constraints_for_graph`, and add them to the guards already
    generated by the tracer. In the future, it may be worthwhile to separate
    them so that we can allow clients to turn off one but not the other.
    (Looking at you, AOTI.)

    NOTE: We should eventually reconcile this logic with `build_guards` that
    is used by AOT Precompile.
    """

    deferred_expressions = []
    new_guards_code = []
    sources: dict[sympy.Expr, str] = {}

    def handle_symint(expr, src):
        if len(expr.free_symbols) == 1:
            # complex equations (e.g., involving derived dims) need to be
            # handled later, since we may not have enough information
            # just as we are passing through the placeholders in order
            deferred_expressions.append((src, expr))
        if expr in sources:
            # expressions that appear in multiple sources should force
            # inputs corresponding to those sources to be equal
            # e.g., x.shape[0] == y.shape[1]
            orig_src = sources[expr]
            new_guards_code.append(f"{src} == {orig_src}")
        else:
            sources[expr] = src
            # process value ranges as elsewhere in export
            min_val, max_val = _convert_range_to_int(range_constraints[expr])
            if min_val > 2:
                new_guards_code.append(f"{src} >= {min_val}")
            if max_val < math.inf:
                new_guards_code.append(f"{src} <= {max_val}")

    for placeholder, path in zip(placeholders, paths_for_placeholders):
        src = "L" + pytree.keystr(path)
        meta = placeholder.meta["val"]
        # specializations
        if isinstance(meta, int):
            new_guards_code.append(f"{src} == {meta}")
        if isinstance(meta, float):
            if meta == math.inf:
                new_guards_code.append(f"{src} == math.inf")
            elif meta == -math.inf:
                new_guards_code.append(f"{src} == -math.inf")
            else:
                new_guards_code.append(f"{src} == {meta}")
        elif isinstance(meta, str):
            new_guards_code.append(f"{src} == '{meta}'")
        # range constraints and equalities
        elif isinstance(meta, torch.SymInt) and meta.node.expr in range_constraints:
            handle_symint(meta.node.expr, src)
        elif isinstance(meta, torch.Tensor):
            for i, dim in enumerate(meta.shape):
                src = "L" + pytree.keystr(path) + f".size()[{i}]"
                if isinstance(dim, int):
                    # specializations
                    new_guards_code.append(f"{src} == {dim}")
                elif (
                    isinstance(dim, torch.SymInt) and dim.node.expr in range_constraints
                ):
                    # range constraints and equalities
                    handle_symint(dim.node.expr, src)

    unification_map: dict[sympy.Symbol, sympy.Expr] = {}
    py_printer = torch.utils._sympy.printers.PythonPrinter()

    # process complex equations (e.g., involving derived dims)
    for src, expr in deferred_expressions:
        # we know this is the only symbol in expr (see check above)
        symbol = next(iter(expr.free_symbols))
        if symbol in sources:
            # if s0 is already known to be directly sourced from inputs,
            # e.g., z.shape[2], we do not need to do anything further
            # (assume we have already processed constraints on s0 above)
            continue

        # otherwise s0 has some "hidden" source like 'dim'
        # example: src = y.shape[1], expr = s0 + 1
        if symbol in unification_map:
            # suppose that we already know that s0 = x.shape[0] * 2,
            # so we can emit the guard: x.shape[0] * 2 + 1 == y.shape[1]
            substitution = expr.subs(unification_map)
            new_guards_code.append(
                py_printer.doprint(sympy.Eq(substitution, sympy.Symbol(src)))
            )
        else:
            # we do not yet know what s0 is, but given s0 + 1 == y.shape[1],
            # we can solve for s0...now knowing that s0 = y.shape[1] - 1
            solution = try_solve(sympy.Eq(expr, sympy.Symbol(src)), symbol)
            if solution is not None:
                definition = solution[1]
                unification_map[symbol] = definition

    return new_guards_code


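# Example (illustrative): for a placeholder tensor of shape (s0, 4) at path
# ['x'], with range_constraints[s0] = ValueRanges(3, 10), this produces:
#
#     ["L['x'].size()[0] >= 3", "L['x'].size()[0] <= 10",
#      "L['x'].size()[1] == 4"]

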
def _ok_to_generate_guards_fn():
    patterns = [
        "executorch",
        "modai",
        "on_device_ai",
        "torchao",
    ]
    # force check_guards=False for files matching `patterns`,
    # because they make many calls to .module() and
    # do not tolerate call_module nodes in the graph
    # TODO: fix these files to handle guard fns
    frame = inspect.currentframe()
    while frame is not None:
        if any(path in frame.f_code.co_filename for path in patterns):
            return False
        frame = frame.f_back

    return True


def _unlift_exported_program_lifted_states(
    ep: ExportedProgram, check_guards=True
) -> torch.fx.GraphModule:
    check_guards = check_guards and _ok_to_generate_guards_fn()

    # TODO T206340015
    if ep.verifiers[0].dialect != "TRAINING":
        ep = _remove_effect_tokens(ep)

    new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
    _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
    forward_arg_names = (
        sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
    )
    lifted_inputs: list[Optional[str]] = [
        (
            in_spec.target
            if in_spec.kind
            in (
                InputKind.BUFFER,
                InputKind.CONSTANT_TENSOR,
                InputKind.PARAMETER,
                InputKind.CUSTOM_OBJ,
            )
            else None
        )
        for in_spec in ep.graph_signature.input_specs
    ]

    mutated_outputs: list[Optional[str]] = [
        (
            out_spec.target
            if out_spec.kind
            in (
                OutputKind.BUFFER_MUTATION,
                OutputKind.USER_INPUT_MUTATION,
                OutputKind.PARAMETER_MUTATION,
            )
            else None
        )
        for out_spec in ep.graph_signature.output_specs
    ]

    source_node_dict = {
        node.name: node for node in ep.graph.nodes if node.op != "placeholder"
    }
    # placeholder node name might change after deepcopy
    placeholder_source_node_dict = {
        node.target: node for node in ep.graph.nodes if node.op == "placeholder"
    }
    for node in new_gm.graph.nodes:
        source_node = None
        if node.op == "placeholder":
            source_node = placeholder_source_node_dict.get(node.target)
        else:
            source_node = source_node_dict.get(node.name)
        node.meta["from_node"] = [
            NodeSource(
                source_node,
                "ExportedProgram.module()",
                NodeSourceAction.CREATE,
            )
        ]

    assert ep.call_spec.in_spec is not None
    new_gm = _unlift(
        new_gm,
        lifted_inputs,
        mutated_outputs,
        ep.call_spec.in_spec,
        ep.call_spec.out_spec,
        forward_arg_names=forward_arg_names,
    )
    unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep)
    unlift_gm.meta.update(ep.graph_module.meta)

    # create a _guards_fn submodule and insert a call to it after placeholders
    graph = unlift_gm.graph
    placeholders = graph.find_nodes(op="placeholder")
    if check_guards and placeholders and ep.example_inputs:
        sig = inspect.signature(unlift_gm.forward)
        input_paths = _get_input_paths(
            ep.example_inputs,
            sig,
        )

        # NOTE: see the band-aid note in _force_ep_signature_match for why the
        # flat_args sources in ep._guards_code are rewritten against the user
        # input paths below.
        guards_code = _get_input_guards_for_graph(
            placeholders, ep.range_constraints, input_paths
        )

        ep_guards_code = _force_ep_signature_match(ep._guards_code, input_paths)
        ep_guards_code = _force_gm_signature_match(ep_guards_code, sig)
        guards_code.extend(ep_guards_code)
        unlift_gm._guards_fn = _convert_guards_code_to_fn(guards_code, input_paths)

        root_nn_module_stack = torch.fx._utils.first_call_function_nn_module_stack(
            graph
        )
        with graph.inserting_after(placeholders[-1]):
            node = graph.call_module("_guards_fn", tuple(placeholders))
            node.meta["nn_module_stack"] = root_nn_module_stack

        unlift_gm.recompile()

    return unlift_gm


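# Illustrative usage (M is a hypothetical nn.Module): ExportedProgram.module()
# is the public entry point that ultimately calls this function.
#
#     ep = torch.export.export(M(), (torch.randn(2, 2),))
#     gm = ep.module()  # params/buffers re-attached as module attributes
#     gm(torch.randn(2, 2))

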
class GuardsFn(torch.nn.Module):
    """
    Module class for guard functions.
    """

    def forward(self, *args):
        pass