# mypy: allow-untyped-defs
import contextlib
import inspect
import logging
from collections import defaultdict
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
import torch.utils._pytree as pytree
from torch._dynamo.source import (
    AttrSource,
    GetItemSource,
    LocalSource,
    TensorProperty,
    TensorPropertySource,
)
from torch._dynamo.variables.builder import TrackedFake
from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim
from torch._export.passes.lift_constants_pass import ConstantAttrMap
from torch._export.utils import _fakify_params_buffers
from torch._guards import Source
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Constraint
from torch.export.dynamic_shapes import (
    _check_dynamic_shapes,
    _combine_args,
    _DimHint,
    _process_dynamic_shapes,
    _RelaxedConstraint,
    _tree_map_with_path,
)
from torch.export.graph_signature import CustomObjArgument
from torch.fx.experimental import _config as config
from torch.fx.experimental.symbolic_shapes import (
    _find_user_code_frame,
    _suggest_fixes_for_data_dependent_error_non_strict,
    ConstraintViolationError,
    DimDynamic,
    EqualityConstraint,
    GuardOnDataDependentSymNode,
    RelaxedUnspecConstraint,
    ShapeEnv,
    StatelessSymbolicContext,
    ValueRanges,
)
from torch.utils._pytree import (
    GetAttrKey,
    KeyPath,
    MappingKey,
    SequenceKey,
    tree_map_with_path,
)


if TYPE_CHECKING:
    from sympy import Symbol


log = logging.getLogger(__name__)


def key_path_to_source(kp: KeyPath) -> Source:
    """
    Given a key path, return the source for the key path.
    """
    source: Source = LocalSource("args")
    for k in kp:
        if isinstance(k, SequenceKey):
            source = GetItemSource(source, k.idx)
        elif isinstance(k, MappingKey):
            source = GetItemSource(source, k.key)
        elif isinstance(k, GetAttrKey):
            source = AttrSource(source, k.name)
        else:
            raise ValueError(f"Unknown KeyEntry {k}")

    return source


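# Illustrative usage sketch (editor-added, never called by this module): shows how a
# pytree KeyPath is translated into a chained Dynamo Source. The concrete key path
# below is an assumed toy value; real key paths are produced by tree_map_with_path
# over the user's (args, kwargs).
def _example_key_path_to_source() -> None:
    kp: KeyPath = (SequenceKey(idx=0), MappingKey(key="x"), GetAttrKey(name="weight"))
    src = key_path_to_source(kp)
    # The resulting name spells out the access path rooted at the "args" local,
    # rendering roughly as args[0]['x'].weight (exact formatting is up to the
    # Source classes).
    log.debug("reconstructed source: %s", src.name())

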
def _is_constant_argument(t):
    return t is None or isinstance(t, (int, float, bool, str))


def fakify(
    mode: FakeTensorMode,
    kp: KeyPath,
    t: Any,
    t_constraints: dict[int, dict[int, Constraint]],
    sources: dict[tuple[int, int], list[Source]],
):
    source = key_path_to_source(kp)
    if _is_constant_argument(t) or isinstance(t, (torch.ScriptObject, torch.nn.Module)):
        return t

    if not isinstance(t, torch.Tensor):
        raise ValueError(f"Unsupported input type {type(t)}")
    n_dims = len(t.shape)
    dynamic_sizes = []
    constraint_sizes = [None] * n_dims
    for i in range(n_dims):
        if i in getattr(t, "_dynamo_weak_dynamic_indices", {}):
            dynamic_sizes.append(DimDynamic.DYNAMIC)
        elif i in getattr(t, "_dynamo_dynamic_indices", {}):
            # bit annoying, but we need to replicate process in _dynamo/variables/builder.py
            # where a RelaxedUnspecConstraint is created for Dim.DYNAMIC, so constraint violations
            # are raised when specializing.
            dynamic_sizes.append(DimDynamic.DYNAMIC)
            constraint_sizes[i] = RelaxedUnspecConstraint(warn_only=False)  # type: ignore[call-overload]
        else:
            dynamic_sizes.append(DimDynamic.STATIC)
    symbolic_context = StatelessSymbolicContext(
        dynamic_sizes=dynamic_sizes,
        constraint_sizes=constraint_sizes,  # type: ignore[arg-type]
    )
    t_id = id(t)
    assert mode.shape_env is not None
    if t_id in t_constraints:
        for i, constraint in t_constraints[t_id].items():
            src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i)
            sources[(t_id, i)].append(src)
            if isinstance(constraint, _RelaxedConstraint):
                continue
            symbolic_context.constraint_sizes[i] = constraint.constraint_range
            mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name  # type: ignore[assignment]
    fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
    mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context))  # type: ignore[union-attr]
    return fake


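# Illustrative usage sketch (editor-added, never called by this module): fakifying a
# single tensor under a fresh FakeTensorMode. With no dynamism markers and no
# constraints, every dimension stays static; the fake tensor is recorded in the shape
# env's tracked_fakes with a source rooted at args[0]. All values here are toy
# assumptions.
def _example_fakify() -> None:
    mode = FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))
    sources: dict[tuple[int, int], list[Source]] = defaultdict(list)
    t = torch.randn(2, 3)
    with mode:
        fake = fakify(mode, (SequenceKey(idx=0),), t, {}, sources)
    log.debug("fake tensor: %s, tracked fakes: %d", fake, len(mode.shape_env.tracked_fakes))

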
def make_fake_inputs(
    nn_module,
    args,
    kwargs,
    dynamic_shapes,
    _is_torch_jit_trace=False,
    allow_complex_guards_as_runtime_asserts=False,
):
    """
    Given an nn module, example inputs, and constraints, return a new fake mode,
    fake inputs created in that mode whose dynamic shape dimensions are constrained
    by the given ranges, and sources for pairs of dynamic shape dimensions that are
    constrained to be equal.
    """
    # TODO(avik): refactor Dynamo to avoid duplication of the following code
    # between non-strict and strict.
    # Specifically, here (non-strict) we do the following pre-tracing steps:
    # - Fakify inputs.
    # - Process input shape equalities.
    # In strict, these steps are spread across multiple files:
    # - output_graph.py fakifies inputs.
    # - [post-tracing] guards.py processes input shape equalities.

    combined_args = _combine_args(nn_module, args, kwargs)
    _check_dynamic_shapes(combined_args, dynamic_shapes)
    constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)
    t_constraints: dict[int, dict[int, Constraint]] = defaultdict(dict)
    for constraint in constraints:
        t_constraints[constraint.t_id][constraint.dim] = constraint

    context = torch._guards.TracingContext.try_get()
    if context is not None:
        # This occurs when we are exporting within dynamo. There already exists
        # a toplevel TracingContext with a fake mode, so we do not want to
        # create another fake mode.
        fake_mode = context.fake_mode
    elif not _is_torch_jit_trace:
        code = nn_module.forward.__code__
        co_fields = {
            "co_name": code.co_name,
            "co_filename": code.co_filename,
            "co_firstlineno": code.co_firstlineno,
        }
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(
                tracked_fakes=[],
                co_fields=co_fields,
                prefer_deferred_runtime_asserts_over_guards=True,
                allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
            ),
            allow_non_fake_inputs=True,
            export=True,
        )
    else:
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(
                tracked_fakes=[],
                prefer_deferred_runtime_asserts_over_guards=True,
                allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
            ),
            allow_non_fake_inputs=True,
        )
    if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None:
        raise ValueError(
            "Detected fake_mode does not have a shape_env with tracked fakes. "
            "If you constructed the module under a FakeTensorMode, "
            "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))"
        )

    with fake_mode:
        # FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock
        if not _is_torch_jit_trace:
            original_signature = inspect.signature(nn_module.forward)
        else:
            original_signature = None
        sources: dict[tuple[int, int], list[Source]] = defaultdict(list)
        fake_args, fake_kwargs = tree_map_with_path(
            lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
            (args, kwargs),
        )

    names: dict[str, tuple[int, int]] = {}
    source_pairs: list[tuple[Source, Source]] = []
    derived_equalities: list[tuple[Source, Union[Source, Symbol], Callable]] = []
    phantom_symbols: dict[str, Symbol] = {}
    relaxed_sources: set[Source] = set()
    for constraint in constraints:
        torch.export.dynamic_shapes._process_equalities(
            constraint,
            lambda t_id, dim: sources[(t_id, dim)],
            fake_mode.shape_env,
            names,
            source_pairs,
            derived_equalities,
            phantom_symbols,
            relaxed_sources,
        )

    equalities_inputs = EqualityConstraint(
        source_pairs=source_pairs,
        derived_equalities=derived_equalities,
        phantom_symbols=list(phantom_symbols.values()),
        relaxed_sources=relaxed_sources,
        warn_only=False,
    )
    return (
        fake_mode,
        fake_args,
        fake_kwargs,
        equalities_inputs,
        original_signature,
        dynamic_shapes,
    )


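# Illustrative usage sketch (editor-added, never called by this module): the typical
# non-strict pre-tracing call. The toy module, input, and the "batch" Dim below are
# assumptions for illustration; torch.export.Dim is the public way to declare a
# dynamic dimension with a range.
def _example_make_fake_inputs() -> None:
    class _M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    batch = torch.export.Dim("batch", min=2, max=64)
    _, fake_args, _, equalities_inputs, _, _ = make_fake_inputs(
        _M(), (torch.randn(4, 3),), {}, dynamic_shapes={"x": {0: batch}}
    )
    # fake_args[0] is a FakeTensor whose dim 0 is a symbolic size bounded by [2, 64];
    # equalities_inputs records which input dimensions must be equal at runtime.
    log.debug("fake input: %s; equalities: %s", fake_args[0], equalities_inputs)

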
def _flatten_dynamic_shapes(
    combined_args: dict[str, Any],
    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]],
) -> list[Any]:
    flat_shapes = []

    def _tree_map_helper(path, t, shape):
        nonlocal flat_shapes
        flat_shapes.append(shape)

    _tree_map_with_path(_tree_map_helper, combined_args, dynamic_shapes)
    return flat_shapes


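# Illustrative usage sketch (editor-added, never called by this module): the flattening
# pairs each tensor leaf of the combined arguments with its corresponding dynamic-shape
# spec, in order. The toy values below are assumptions; the exact result depends on how
# _tree_map_with_path matches the two trees.
def _example_flatten_dynamic_shapes() -> None:
    combined = {"x": torch.randn(2, 3), "y": torch.randn(4)}
    specs = _flatten_dynamic_shapes(combined, {"x": {0: None}, "y": None})
    # Expected to yield one spec per tensor leaf, here roughly [{0: None}, None].
    log.debug("flattened dynamic shapes: %s", specs)

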
def _clean_dynamic_markers(tensor: torch.Tensor) -> None:
    for attr in [
        "_dynamo_weak_dynamic_indices",
        "_dynamo_dynamic_indices",
        "_dynamo_dynamic_range",
        "_dynamo_static_indices",
        "_dynamo_unbacked_indices",
    ]:
        if hasattr(tensor, attr):
            delattr(tensor, attr)


def produce_guards_and_solve_constraints(
    fake_mode: FakeTensorMode,
    gm: torch.fx.GraphModule,
    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
    equalities_inputs: EqualityConstraint,
    original_signature: inspect.Signature,
    _is_torch_jit_trace=False,
):
    """
    Given a fake mode, source pairs corresponding to equal dynamic shape dimensions,
    and a graph module, produce guards on the fake mode's shape env (raising constraint
    violations if any), then solve them (to suggest simplifications or fixes).
    Dynamo already performs this, so this is for non-strict mode.

    Additional inputs:
        equalities_inputs: the equality constraints to use for guards
        original_signature: the signature of the forward method
    """
    shape_env = fake_mode.shape_env
    assert shape_env is not None
    assert shape_env.tracked_fakes is not None

    placeholders = [tf.fake for tf in shape_env.tracked_fakes]
    sources = [tf.source for tf in shape_env.tracked_fakes]
    input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes]
    constraint_violation_error = None
    try:
        shape_env.produce_guards(
            placeholders,
            sources,
            input_contexts=input_contexts,
            equalities_inputs=equalities_inputs,
            ignore_static=False,
        )
    except ConstraintViolationError as e:
        constraint_violation_error = e

    shape_env.frozen = True
    dim_constraints = shape_env.dim_constraints
    if dim_constraints is None:
        # Expected when shape_env.produce_guards throws an early constraint violation error.
        # There is nothing to solve for in this case.
        # TODO(avik): Maybe record the constraint violation error instead and replay later?
        assert constraint_violation_error
        raise constraint_violation_error
    dim_constraints.solve()
    forced_specializations = dim_constraints.forced_specializations()
    if not _is_torch_jit_trace:
        msg = dim_constraints.prettify_results(
            original_signature,
            dynamic_shapes,  # type: ignore[arg-type]
            constraint_violation_error,
            forced_specializations,  # type: ignore[arg-type]
        )
    else:
        # FIXME(ycao): This is a hack to get around missing signature from ScriptMethod
        msg = "dummy constraint violation message"
    if constraint_violation_error:
        constraint_violation_error.args = (constraint_violation_error.args[0] + msg,)
    elif forced_specializations:
        constraint_violation_error = ConstraintViolationError(msg)
    if constraint_violation_error:
        raise constraint_violation_error


def make_constraints(
    fake_mode: FakeTensorMode,
    gm: torch.fx.GraphModule,
    combined_args: dict[str, Any],
    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
    num_lifted_inputs: int,
):
    """
    Given a fake mode's shape env and user-specified dynamic shapes,
    return the resulting range constraints and equality constraints.

    Additional args:
        num_lifted_inputs: the number of non-user-input placeholder nodes in the graph
        (used only to enumerate the user-input nodes)
    """

    shape_env = fake_mode.shape_env
    assert shape_env is not None
    inline_constraints = gm.meta.get("inline_constraints", [])
    range_constraints = {
        symbol: inline_constraints[symbol] for symbol in inline_constraints
    }
    if not dynamic_shapes:
        return range_constraints

    # clean up dynamic markers from tensors
    for arg in pytree.tree_flatten(combined_args)[0]:
        if isinstance(arg, torch.Tensor):
            _clean_dynamic_markers(arg)

    # get individual dynamic shapes spec for each input
    if not isinstance(dynamic_shapes, dict):
        assert isinstance(dynamic_shapes, (tuple, list))
        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]
    flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes)

    # check number of shapes vs. number of inputs
    num_placeholders = [node.op == "placeholder" for node in gm.graph.nodes].count(True)
    assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs

    input_dims = defaultdict(list)
    free_symbols = set()
    for input_index, node in enumerate(gm.graph.nodes):
        if input_index < num_lifted_inputs or node.op != "placeholder":
            continue
        if _is_constant_argument(node.meta["val"]) or isinstance(
            node.meta["val"], CustomObjArgument
        ):
            continue
        shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs]
        for i, d in enumerate(node.meta["val"].shape):
            if isinstance(d, torch.SymInt) and not d.node.expr.is_number:
                # Look up the range constraint for the symbol corresponding to this shape dimension
                # and store it indexed by the symbolic expression corresponding to it.
                # NOTE(avik): Use node._expr instead of node.expr for the lookup here because
                # we want the symbol, not its replacement, which could be an expression. Maybe
                # there's a better way to do this, e.g., by (re)computing value ranges for expressions?
                dim = shape_spec[i] if shape_spec else None
                if dim is None or isinstance(dim, _DimHint):
                    range_constraints[d.node.expr] = shape_env.var_to_range[
                        d.node._expr
                    ]
                else:
                    range_constraints[d.node.expr] = ValueRanges(
                        lower=dim.min, upper=dim.max
                    )
                input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i))
                free_symbols.update(d.node.expr.free_symbols)

    for symbol in free_symbols:
        if symbol not in range_constraints:
            # Placeholders can have symbolic shapes that are derived expressions.
            # The above code will record direct range constraints for them
            # so that we can do runtime assertions. In addition, for serde checks
            # we want to record range constraints for their root symbols.
            range_constraints[symbol] = shape_env.var_to_range[symbol]

    return range_constraints


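# Illustrative usage sketch (editor-added, never called by this module): make_constraints
# runs inside torch.export, and its result is surfaced as ExportedProgram.range_constraints.
# The toy module, input, and "batch" Dim below are assumptions for illustration.
def _example_range_constraints_via_export() -> None:
    class _M(torch.nn.Module):
        def forward(self, x):
            return x * 2

    batch = torch.export.Dim("batch", min=2, max=128)
    ep = torch.export.export(
        _M(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}}
    )
    # Each shape symbol (e.g. s0) maps to a ValueRanges such as [2, 128].
    log.debug("range constraints: %s", ep.range_constraints)

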
def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
    """Search the module hierarchy, gathering up all tensor and ScriptObject constants.

    Returns a dictionary mapping hash(value) to the name of the constant. We
    have to abuse `hash` here unfortunately, see: [ScriptObject hash].
    """
    constants = ConstantAttrMap()
    buffers_parameters = set(m.buffers())
    buffers_parameters.update(m.parameters())

    def inner(m: torch.nn.Module, prefix_atoms: list[str], constants):
        for k, v in m.__dict__.items():
            if isinstance(
                v,
                (
                    torch.Tensor,
                    torch.ScriptObject,
                    FakeScriptObject,
                ),
            ):
                if v in buffers_parameters:
                    # filter out buffers and parameters, leaving only constants
                    continue

                fqn = ".".join(prefix_atoms + [k])
                constants.add(v, fqn)
        for k, v in m.named_children():
            inner(v, prefix_atoms + [k], constants)

    inner(m, [], constants)
    return constants


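# Illustrative usage sketch (editor-added, never called by this module): a plain tensor
# stored as a module attribute (neither a Parameter nor a registered buffer) is picked up
# as a constant, while parameters and buffers are filtered out. The toy module below is
# an assumption for illustration.
def _example_gather_constant_attrs() -> None:
    class _M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(3))  # filtered out
            self.scale = torch.tensor(2.0)  # collected as a constant under fqn "scale"

    constants = _gather_constant_attrs(_M())
    log.debug("constant attrs: %s", constants)

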
def _get_graph_inputs_of_type_nn_module(
    args: Optional[tuple[tuple[Any], dict[Any, Any]]],
) -> set[type[torch.nn.Module]]:
    if args is None:
        return set()
    module_types = set()
    for arg in pytree.tree_leaves(args):
        if isinstance(arg, torch.nn.Module):
            module_types.add(type(arg))
    return module_types


def _enter_enable_graph_inputs_of_type_nn_module(
    module_types: set[type[torch.nn.Module]],
) -> None:
    for t in module_types:
        torch._export.utils.register_module_as_pytree_input_node(t)


def _exit_enable_graph_inputs_of_type_nn_module(
    module_types: set[type[torch.nn.Module]],
) -> None:
    for t in module_types:
        torch._export.utils.deregister_module_as_pytree_input_node(t)


@contextlib.contextmanager
def _enable_graph_inputs_of_type_nn_module(
    args: Optional[tuple[tuple[Any], dict[Any, Any]]],
):
    if args is None:
        yield
        return

    module_types = _get_graph_inputs_of_type_nn_module(args)
    _enter_enable_graph_inputs_of_type_nn_module(module_types)
    try:
        yield
    finally:
        _exit_enable_graph_inputs_of_type_nn_module(module_types)


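# Illustrative usage sketch (editor-added, never called by this module): any nn.Module
# that appears as a graph input has its type temporarily registered as a pytree input
# node for the duration of the context, and deregistered on exit. The toy submodule
# below is an assumption for illustration.
def _example_enable_graph_inputs_of_type_nn_module() -> None:
    sub = torch.nn.Linear(2, 2)
    with _enable_graph_inputs_of_type_nn_module(((sub,), {})):
        # Inside the context, torch.nn.Linear is treated as a pytree input node, so it
        # can flow through export's input flattening machinery.
        pass

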
@contextlib.contextmanager
def _fakify_module_inputs(
    args: tuple[Any],
    kwargs: dict[Any, Any],
    fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
):
    # This context manager is used to fakify module inputs.
    # Inputs:
    #   args, kwargs: the args and kwargs containing module inputs that haven't been fakified.
    #   fake_mode: the fake mode used to fakify the parameters and buffers of those
    #     module inputs. It's the same mode that fakifies the input tensors.

    ctxs = [_enable_graph_inputs_of_type_nn_module((args, kwargs))]
    for arg in pytree.tree_leaves((args, kwargs)):
        if isinstance(arg, torch.nn.Module):
            fake_params_buffers = _fakify_params_buffers(fake_mode, arg)
            ctxs.append(
                torch.nn.utils.stateless._reparametrize_module(
                    arg,
                    fake_params_buffers,
                    tie_weights=True,
                    strict=True,
                    stack_weights=True,
                )
            )
    with contextlib.ExitStack() as stack:
        for ctx in ctxs:
            stack.enter_context(ctx)
        yield


@contextlib.contextmanager
def _fakify_script_objects(
    mod: torch.nn.Module,
    args: tuple[Any],
    kwargs: dict[Any, Any],
    fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
):
    # This context manager is used to fakify script objects into FakeScriptObject.
    # Inputs:
    #   mod: the module to be exported; the script object attributes on it (and its
    #     recursive submodules) haven't been fakified yet.
    #   args, kwargs: the args and kwargs inputs for mod; script object inputs haven't
    #     been fakified yet.
    #   fake_mode: the fake mode to be used for fakifying script objects. It's the same
    #     mode that fakifies the input tensors.
    #
    # Returns:
    #   mod: the patched module; the script object attributes on it (and its recursive
    #     submodules) have been fakified.
    #   fake_args, fake_kwargs: new args and kwargs in which script object inputs have
    #     been fakified. Tensors are left untouched.
    #   fake_constant_attrs: a new map from each FakeScriptObject to the fqn of the
    #     original script object.
    #   fake_to_real: a mapping from FakeScriptObject to the original script object,
    #     used to undo the patching.

    constant_attrs: ConstantAttrMap = _gather_constant_attrs(mod)
    assert not any(
        isinstance(obj, FakeScriptObject) for obj in constant_attrs.values()
    ), "Mod shouldn't contain any FakeScriptObject."
    assert not pytree.tree_any(
        lambda obj: isinstance(obj, FakeScriptObject), (args, kwargs)
    ), "args and kwargs shouldn't contain any FakeScriptObject."

    patched_attr = {}
    fake_constant_attrs = ConstantAttrMap()
    fake_to_real = {}

    def _maybe_fakify_obj(obj):
        fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, obj)
        fake_to_real[fake_obj] = obj
        return fake_obj

    def _leaf_mod_and_attr(
        mod: torch.nn.Module, attr_fqn: str
    ) -> tuple[torch.nn.Module, str]:
        *prefix_attr, last_attr = attr_fqn.split(".")
        cur_mod = mod
        for attr in prefix_attr:
            cur_mod = getattr(cur_mod, attr)
        return cur_mod, last_attr

    try:
        for obj, fqns in constant_attrs.items():
            if isinstance(obj, torch.ScriptObject):
                fake_script_obj = _maybe_fakify_obj(obj)
                for fqn in fqns:
                    cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
                    assert obj is getattr(cur_mod, attr)
                    setattr(cur_mod, attr, fake_script_obj)
                    fake_constant_attrs.add(fake_script_obj, fqn)
                    patched_attr[fqn] = obj
            else:
                for fqn in fqns:
                    fake_constant_attrs.add(obj, fqn)

        fake_args, fake_kwargs = pytree.tree_map_only(
            torch.ScriptObject, _maybe_fakify_obj, (args, kwargs)
        )
        yield (mod, fake_args, fake_kwargs, fake_constant_attrs, fake_to_real)
    finally:
        for fqn, orig_obj in patched_attr.items():
            cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
            setattr(cur_mod, attr, orig_obj)


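# Illustrative usage sketch (editor-added, never called by this module): for a module
# with no ScriptObject attributes or inputs, the context is a no-op: nothing is patched,
# fake_constant_attrs is empty, and the args come back unchanged. The toy module and
# inputs below are assumptions for illustration.
def _example_fakify_script_objects() -> None:
    mode = FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))
    m = torch.nn.Linear(2, 2)
    with _fakify_script_objects(m, (torch.randn(2, 2),), {}, mode) as (
        patched_mod,
        fake_args,
        fake_kwargs,
        fake_constant_attrs,
        fake_to_real,
    ):
        log.debug("patched: %s, constants: %s", patched_mod, fake_constant_attrs)

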
class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
    """
    1. Handles data-dependent errors raised by torch function calls in non-strict.

    Any data-dependent error is due to some condition on unbacked symints
    that cannot be resolved. A mechanical way of fixing the error is to use
    a torch._check() call to assert either that condition or its negation.
    The handler suggests these options as code and points to the location
    of the torch function call that raised the error as part of the error
    message shown to the user, who can then simply select and copy-paste
    a suggested fix at that location.

    NOTE: Not all data-dependent errors are raised by torch function calls.
    In particular, conditions on unbacked symints can appear outside such
    calls, and as such are not handled here.

    2. Overrides torch functions that are known to cause problems in non-strict.

    Certain Python features, such as indexing/slicing, cannot be intercepted
    in non-strict. When these features need special handling in the compiler,
    tracing can fail in non-strict (yet surprisingly succeed in strict).
    Fortunately, redirecting to other torch functions can often fix such issues.

    3. Handles line-of-code logging for each torch function call in non-strict.

    Usage: TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC=1 TORCH_LOGS="+export" ...
    """

    def _override(self, func, args, kwargs):
        if func is torch.tensor:
            # Redirect to Python implementation of torch.tensor for data with symints.
            # NOTE(avik): We don't unconditionally redirect to this implementation
            # because it has some known incompletenesses, e.g., it doesn't support
            # empty data. See https://github.com/pytorch/pytorch/issues/143216
            if any(
                isinstance(a, torch.SymInt) for a in pytree.tree_flatten(args[0])[0]
            ):
                return torch._refs.tensor, args, kwargs
        if func.__name__ == "__getitem__" and isinstance(args[0], torch.Tensor):
            # Redirect to torch.select for indexing with symint.
            if isinstance(args[1], torch.SymInt):
                return torch.select, [args[0], 0, args[1]], {}
        return func, args, kwargs

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if torch.compiler.is_dynamo_compiling():
            return func(*args, **kwargs)

        if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
            frame = _find_user_code_frame()
            if frame is not None:
                log.debug(
                    "%s called at %s:%s in %s",
                    func.__qualname__,
                    frame.f_code.co_filename,
                    frame.f_lineno,
                    frame.f_code.co_name,
                )

        func, args, kwargs = self._override(func, args, kwargs)
        try:
            return func(*args, **kwargs)
        except GuardOnDataDependentSymNode as e:
            _suggest_fixes_for_data_dependent_error_non_strict(e)
            raise
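

# Illustrative usage sketch (editor-added, never called by this module): the handler is
# a TorchFunctionMode that non-strict export enters around user code. While it is active,
# every torch function call flows through __torch_function__, where problematic calls are
# redirected and data-dependent failures are re-raised with torch._check() suggestions.
# The call below is an assumed toy example.
def _example_non_strict_handler() -> None:
    with _NonStrictTorchFunctionHandler():
        y = torch.add(torch.randn(2), 1.0)  # routed through __torch_function__
    log.debug("result: %s", y)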