Files
pytorch/torch/_inductor/compile_fx.py
James Wu 1b772de397 Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048)
When running BundledAOTAutogradCache with precompile, we still need to run Triton bundling so that the precompiled CompiledFxGraph contains the Triton CUDA kernels. We also pre-save the autotune results in the precompile artifact.

It would be even better to pre-trim the CUDA kernels on save and then apply them; we can work on that later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158048
Approved by: https://github.com/zhxchen17
2025-07-22 14:12:21 +00:00

2702 lines
104 KiB
Python

from __future__ import annotations
import contextlib
import enum
import functools
import io
import itertools
import json
import logging
import os
import sys
import time
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import AbstractContextManager
from dataclasses import dataclass
from inspect import currentframe
from itertools import count
from operator import attrgetter
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Never, override, ParamSpec, Protocol, TypedDict, Unpack
from unittest import mock
import torch._inductor.async_compile
import torch.fx
import torch.utils._pytree as pytree
from functorch.compile import min_cut_rematerialization_partition
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo import (
compiled_autograd,
config as dynamo_config,
logging as dynamo_logging,
utils as dynamo_utils,
)
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.repro.after_aot import wrap_compiler_debug
from torch._dynamo.utils import (
chromium_event_timed,
CompileEventLogger,
counters,
detect_fake_mode,
dynamo_timed,
flatten_graph_inputs,
get_metrics_context,
lazy_format_graph_code,
set_feature_use,
)
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.subclass_parametrization import (
unwrap_tensor_subclass_parameters,
)
from torch._functorch.aot_autograd import (
aot_export_module,
GraphOutputName,
make_boxed_func,
SerializableAOTDispatchCompiler,
)
from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
from torch._inductor.cudagraph_utils import (
BoxedDeviceIndex,
format_default_skip_message,
log_cudagraph_skip_and_bump_counter,
PlaceholderInfo,
)
from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._inductor.output_code import (
CompiledAOTI,
CompiledFxGraph,
CompiledFxGraphConstantsWithGm,
get_expanded_dims,
index_expanded_dims,
OutputCode,
)
from torch._inductor.runtime.cache_dir_utils import cache_dir
from torch._inductor.utils import (
BoxedBool,
count_tangents,
fresh_cache,
get_all_devices,
InputType,
is_gpu,
should_assume_input_aligned,
should_use_remote_fx_graph_cache,
tensor_is_aligned,
)
from torch._library.fake_class_registry import FakeScriptObject
from torch._logging import trace_structured
from torch._utils_internal import compile_time_strobelight_meta
from torch.fx import GraphModule
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.monitor import _WaitCounter
from torch.utils._ordered_set import OrderedSet
from .._dynamo.backends.common import aot_autograd
from .._dynamo.exc import ShortenTraceback, SkipFrame
from ..fx._lazy_graph_module import _use_lazy_graph_module
from ..fx.graph import _PyTreeCodeGen
from ..utils._triton import has_triton
from . import config, metrics
from .codegen.common import get_wrapper_codegen_for_device, init_backend_registration
from .debug import DebugContext
from .decomposition import select_decomp_table
from .exc import InductorError
from .fx_passes.joint_graph import joint_graph_passes
from .fx_passes.post_grad import post_grad_passes, view_to_reshape
from .fx_passes.pre_grad import pre_grad_passes
from .graph import GraphLowering
from .ir import get_device_type, IRNode
from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401
from .triton_bundler import TritonBundler
from .utils import (
align_inputs_from_check_idxs,
clone_preserve_strides,
copy_misaligned_inputs,
get_cloned_parameter_buffer_name,
get_first_incompatible_cudagraph_node,
maybe_get_suppress_shape_guards_ctx,
output_node,
remove_unaligned_input_idxs,
shape_env_from_inputs,
)
from .virtualized import V
if TYPE_CHECKING:
from collections.abc import Generator, Sequence
from torch._inductor.output_code import _StrideExprStr
from torch._ops import OpOverload
from torch.export.pt2_archive._package_weights import Weights
from .ir import ExternKernelNode
_P = ParamSpec("_P")
_T = TypeVar("_T")

if not TYPE_CHECKING and config.is_fbcode():
    # Outside of static type checking, and when config.is_fbcode() reports an
    # fbcode build, use the implementations from torch._inductor.fb.utils.
    from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
else:
    # Fallbacks (also what type checkers see): both helpers do nothing.

    def time_and_log(attr: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
        """No-op decorator factory: ``attr`` is ignored and
        ``dynamo_utils.identity`` is returned as the decorator."""
        return dynamo_utils.identity

    def log_optimus_to_scuba(*args: object, **kwargs: object) -> None:
        """Accept and discard any arguments."""
        return None
if TYPE_CHECKING:
from torch._functorch._aot_autograd.schemas import (
FQN,
GraphInputName,
GraphSignature,
)
class FxCompileMode(enum.Enum):
    """
    Scheme used to turn an FX graph into compiled output. Selected through the
    ``TORCHINDUCTOR_FX_COMPILE_MODE`` environment variable, which is parsed by
    ``_fx_compile_mode_default``.
    """

    # Default; used when the environment variable is unset or invalid.
    NORMAL = 0
    # For testing - use the serde FxCompile scheme to debug serialization and
    # deserialization of GraphModule and CompiledFxGraph.
    SERIALIZE = 1
    # Compile using a subprocess instead of in-process.
    SUBPROCESS = 2
@dataclass
class FxCompileConfig:
    """Parsed form of the ``TORCHINDUCTOR_FX_COMPILE_MODE`` environment variable."""

    # Which FxCompile scheme to use.
    mode: FxCompileMode
    # True when the variable carried an "async+" prefix.
    use_async: bool
    # True when the variable carried a "progressive+" prefix.
    use_progressive: bool
def _fx_compile_mode_default() -> FxCompileConfig:
    """
    Build an :class:`FxCompileConfig` from ``TORCHINDUCTOR_FX_COMPILE_MODE``.

    The value is an :class:`FxCompileMode` member name, matched
    case-insensitively. It may be preceded by the modifiers ``progressive+``
    and then ``async+``, in that order, for example
    ``"progressive+async+subprocess"``.

    Returns the default config (NORMAL, not async, not progressive) when the
    variable is unset. An unrecognized mode name is logged as an error and
    removed from the environment, so subprocesses do not report it again. The
    default config is then returned.
    """
    name = "TORCHINDUCTOR_FX_COMPILE_MODE"
    value = os.environ.get(name)
    if value is None:
        return FxCompileConfig(FxCompileMode.NORMAL, False, False)

    def _strip_modifier(text: str, prefix: str) -> tuple[str, bool]:
        # Case-insensitive prefix match. Slice by len(prefix) so the cut
        # length can never drift from the literal being matched.
        if text.lower().startswith(prefix):
            return text[len(prefix) :], True
        return text, False

    value, use_progressive = _strip_modifier(value, "progressive+")
    value, use_async = _strip_modifier(value, "async+")

    mode_name = value.upper()
    try:
        mode = FxCompileMode[mode_name]
    except KeyError:
        # This runs at import time, before the module-level ``log`` exists,
        # so fetch the logger locally.
        logger = logging.getLogger(__name__)
        logger.error(
            "Invalid value of %s for %s. Expected one of %s. Using default.",
            mode_name,
            name,
            ", ".join(sorted(repr(x) for x in FxCompileMode.__members__)),
        )
        # Remove from the environment so subprocesses don't ALSO complain.
        os.environ.pop(name, None)
        return FxCompileConfig(FxCompileMode.NORMAL, False, False)
    return FxCompileConfig(mode, use_async, use_progressive)
def _get_progression_configs() -> list[dict[str, Any]]:
# TODO make this configurable
return [
{"max_autotune": True},
]
# Parse TORCHINDUCTOR_FX_COMPILE_MODE once, at import time, and expose each
# setting under its own module-level name.
_fx_compile_config = _fx_compile_mode_default()
fx_compile_mode, fx_compile_async, fx_compile_progressive = (
    _fx_compile_config.mode,
    _fx_compile_config.use_async,
    _fx_compile_config.use_progressive,
)

# Module logger, plus per-artifact loggers obtained from torch._logging.
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
pre_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "pre_grad_graphs")
post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs")
static_inputs_log = torch._logging.getArtifactLogger(__name__, "cudagraph_static_inputs")
inductor_metrics_log = torch._logging.getArtifactLogger(__name__, "inductor_metrics")
def get_static_input_idxs(num_fixed: int) -> list[int]:
# If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
# of cudagraphs. Rather than copying these into cudagraph-owned memory
# like we do for normal inputs on each run, we will re-record a cudagraph if these
# parameter locations change.
context = torch._guards.TracingContext.try_get()
fixed = list(range(num_fixed))
if not context or not context.fw_metadata:
return fixed
return context.fw_metadata.static_input_indices
def record_original_output_strides(gm: GraphModule) -> None:
output_node = gm.graph.find_nodes(op="output")[0]
output_strides = []
if not isinstance(output_node.args[0], torch.fx.Node):
output_node_args = output_node.args[0]
else:
output_node_args = output_node.args
for output in output_node_args:
if (
isinstance(output, torch.fx.Node)
and (val := output.meta.get("val")) is not None
and isinstance(val, torch.Tensor)
):
output_strides.append(val.stride())
else:
output_strides.append(None)
output_node.meta["original_output_strides"] = output_strides
def _recursive_record_original_output_strides(gm: GraphModule) -> None:
    """
    Run ``record_original_output_strides`` on ``gm`` and, first and
    depth-first, on every subgraph that ``gm`` calls through the
    ``invoke_subgraph`` higher-order op.
    """
    # invoke_subgraph HOP requires output strides to be respected
    hop_calls = gm.graph.find_nodes(
        op="call_function", target=torch.ops.higher_order.invoke_subgraph
    )
    for hop_call in hop_calls:
        _recursive_record_original_output_strides(getattr(gm, hop_call.args[0].target))
    record_original_output_strides(gm)
def _recursive_record_user_visible_output_idxs(gm: GraphModule) -> None:
# invoke_subgraph HOP requires output strides to be respected
for node in gm.graph.find_nodes(
op="call_function", target=torch.ops.higher_order.invoke_subgraph
):
subgraph = getattr(gm, node.args[0].target)
for node in subgraph.graph.find_nodes(op="output"):
node.meta["user_visible_output_idxs"] = [
idx
for idx in range(len(node.args[0]))
if isinstance(node.args[0][idx], torch.fx.Node)
]
_recursive_record_user_visible_output_idxs(subgraph)
@functools.cache
def _step_logger() -> Callable[..., None]:
    """
    Return the dynamo step logger bound to this module's ``log``.

    It is created on the first call and cached for the rest of the process.
    """
    step_logger = dynamo_logging.get_step_logger(log)
    return step_logger
@functools.cache
def _warn_tf32_disabled() -> None:
if (
torch.cuda.is_available()
and not torch.backends.cuda.matmul.allow_tf32
and torch.cuda.get_device_capability() >= (8, 0)
):
warnings.warn(
"TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
"Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
)
def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None:
"""
In aot_export_module (make_fx), we create get_attr nodes with name prefix
"_tensor_constant" and "_torchbind_obj". See Tracer.create_arg() in
torch/fx/_symbolic_trace.py
However, this might result in name collision if the original mod already
has a different buffer with the same name.
We resolve this potential name collision here by changing the target name
with a new number post fix.
"""
existing_keys = OrderedSet(
[name for name, val in mod.named_parameters(remove_duplicate=False)]
)
existing_keys.update(
OrderedSet([name for name, val in mod.named_buffers(remove_duplicate=False)])
)
def find_smallest_i(graph: fx.Graph, prefix: str) -> int:
i = 0
for node in graph.nodes:
if node.op == "get_attr" and node.target.startswith(prefix):
if len(node.target) > len(prefix):
post_fix = node.target.split(prefix)[-1]
if post_fix.isdigit():
i = max(i, int(post_fix))
for key in existing_keys:
if key.startswith(prefix):
if len(key) > len(prefix):
post_fix = key.split(prefix)[-1]
if post_fix.isdigit():
i = max(i, int(post_fix))
return i + 1
for node in gm.graph.nodes:
if node.op == "get_attr":
target_name = node.target
if not target_name.startswith(
"_tensor_constant"
) and not target_name.startswith("_torchbind_obj"):
continue
if not hasattr(mod, target_name):
continue
gm_target = attrgetter(target_name)(gm)
model_target = attrgetter(target_name)(mod)
if isinstance(gm_target, FakeScriptObject):
if (
isinstance(model_target, FakeScriptObject)
and gm_target.real_obj is model_target.real_obj
):
continue
elif (
torch.equal(gm_target, model_target)
and gm_target.dtype == model_target.dtype
):
continue
prefix = (
"_tensor_constant"
if target_name.startswith("_tensor_constant")
else "_torchbind_obj"
)
new_id = find_smallest_i(gm.graph, prefix)
new_target_name = f"{prefix}{new_id}"
node.target = new_target_name
setattr(gm, new_target_name, gm_target)
existing_keys.add(new_target_name)
def _unlift_graph(
    mod: GraphModule, gm: GraphModule, graph_signature: GraphSignature
) -> GraphModule:
    """
    Re-attach ``mod``'s parameters and buffers to ``gm`` and unlift it.

    Steps:
    * Resolve constant-name collisions between ``mod`` and ``gm``.
    * Assign every parameter and buffer of ``mod`` onto ``gm`` under its FQN.
    * Build two lists and pass them to ``torch.export._unlift._unlift``:
      - for each placeholder, the parameter/buffer FQN it was lifted from
        (None for user inputs);
      - for each graph output, the buffer FQN or user-input name it mutates
        (None otherwise).

    Buffers that feed placeholders are also cloned (strides preserved) into
    ``gm.meta``.

    Returns the GraphModule produced by ``_unlift``.
    """
    from torch.export.unflatten import _assign_attr, _AttrKind

    _resolve_name_collision(mod, gm)

    # FQN -> tensor for every parameter and buffer of ``mod``; each is also
    # attached to ``gm`` under the same FQN.
    state_dict: dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {}
    for name, param in mod.named_parameters(remove_duplicate=False):
        state_dict[name] = param
        _assign_attr(
            param,
            gm,
            name,
            attr_kind=_AttrKind.PARAMETER,
        )
    for name, buffer in mod.named_buffers(remove_duplicate=False):
        state_dict[name] = buffer
        _assign_attr(
            buffer,
            gm,
            name,
            attr_kind=_AttrKind.BUFFER,
        )

    placeholder_nodes = gm.graph.find_nodes(op="placeholder")
    # One entry per placeholder: the lifted FQN, or None for a user input.
    lifted_inputs: list[Optional[FQN]] = []

    # In AOTI, module parameters and buffers are not lifted as graph inputs.
    # As a result, mutation to buffers has side effect which makes their initial
    # values different from Eager. So we clone them here as a copy.
    # We are not cloning for parameters, although it will be needed if we want to
    # support training.
    for node in placeholder_nodes:
        node_name = node.name
        if node_name in graph_signature.inputs_to_parameters:
            parameter_name = graph_signature.inputs_to_parameters[node_name]
            lifted_inputs.append(parameter_name)
        elif node_name in graph_signature.inputs_to_buffers:
            buffer_name = graph_signature.inputs_to_buffers[node_name]
            lifted_inputs.append(buffer_name)
            gm.meta[get_cloned_parameter_buffer_name(buffer_name)] = (
                clone_preserve_strides(state_dict[buffer_name])
            )
        else:
            assert node_name in graph_signature.user_inputs
            lifted_inputs.append(None)

    from torch.export._unlift import _unlift

    outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0])  # type: ignore[arg-type]
    # One entry per output: the mutated buffer FQN / user-input name, or None.
    mutated_outputs = []
    buffer_mutations = graph_signature.buffers_to_mutate
    user_input_mutations = graph_signature.user_inputs_to_mutate
    output_tokens = graph_signature.output_tokens
    for idx, out in enumerate(outputs):
        value: Optional[Union[FQN, GraphInputName]] = None

        # Only outputs whose index is below the total count of buffer
        # mutations, user-input mutations and output tokens are looked up;
        # later outputs keep value None.
        if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
            name = GraphOutputName(out.name)
            if name in buffer_mutations:
                value = buffer_mutations[name]
            elif name in user_input_mutations:
                value = user_input_mutations[name]

        mutated_outputs.append(value)

    unlifted_gm = _unlift(
        gm,
        lifted_inputs,
        mutated_outputs,
        pytree.LeafSpec(),
        None,
    )
    return unlifted_gm
def _get_subgraph_names(
gm: GraphModule, skip_invoke_subgraph: bool = False
) -> Generator[str, None, None]:
all_subgraph_names: OrderedSet[str] = OrderedSet(
x.target for x in gm.graph.find_nodes(op="get_attr")
)
fx_subgraph_names: OrderedSet[str] = OrderedSet()
for child_name, child_module in gm.named_children():
# Sometimes an owning_module can have unused children. Skip them
# by checking them from get_attr node targets.
if child_name in all_subgraph_names and isinstance(
child_module, torch.fx.GraphModule
):
fx_subgraph_names.add(child_name)
if skip_invoke_subgraph:
for node in gm.graph.find_nodes(
op="call_function", target=torch.ops.higher_order.invoke_subgraph
):
fx_subgraph_names.discard(node.args[0].target)
yield from fx_subgraph_names
def _recursive_pre_grad_passes(
    gm: GraphModule,
    example_inputs: Sequence[InputType],
) -> GraphModule:
    """
    Apply ``pre_grad_passes`` to ``gm`` and, before that, recursively to each
    GraphModule subgraph it uses (see ``_get_subgraph_names``).

    Each processed subgraph replaces the original attribute on ``gm``. The
    additional and removed passes come from ``config.add_pre_grad_passes`` and
    ``config.remove_pre_grad_passes``. Returns whatever ``pre_grad_passes``
    returns for ``gm``.
    """
    with dynamo_timed(
        "_recursive_pre_grad_passes",
        log_pt2_compile_event=True,
        dynamo_compile_column_us="pre_grad_pass_time_us",
    ):
        extra_passes = config.add_pre_grad_passes
        skipped_passes = config.remove_pre_grad_passes
        for child_name in _get_subgraph_names(gm):
            # There are no per-subgraph example inputs, so recurse with ().
            rewritten = _recursive_pre_grad_passes(getattr(gm, child_name), ())
            setattr(gm, child_name, rewritten)
        return pre_grad_passes(gm, example_inputs, extra_passes, skipped_passes)
def _recursive_joint_graph_passes(
    gm: GraphModule, skip_invoke_subgraph: bool = False
) -> None:
    """
    Run ``joint_graph_passes`` on ``gm``, after first running them
    (depth-first) on every GraphModule subgraph it uses.

    Pass ``skip_invoke_subgraph=True`` to leave ``invoke_subgraph`` subgraphs
    alone. In AOTAutograd, ``run_joint_graph_passes_on_hops`` partitions those
    HOPs before the partitioner runs on the outer graph, and its partition_fn
    already calls this function on each subgraph, so visiting them again
    would be redundant.
    """
    timer = dynamo_timed(
        "_recursive_joint_graph_passes",
        log_pt2_compile_event=True,
        dynamo_compile_column_us="joint_graph_pass_time_us",
    )
    with timer:
        for child_name in _get_subgraph_names(gm, skip_invoke_subgraph):
            child = getattr(gm, child_name)
            _recursive_joint_graph_passes(child, skip_invoke_subgraph)
        joint_graph_passes(gm)
def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) -> None:
    """
    Run ``post_grad_passes(gm, is_inference)`` after first applying it,
    depth-first, to every GraphModule subgraph ``gm`` uses.
    """
    with dynamo_timed(
        "_recursive_post_grad_passes",
        log_pt2_compile_event=True,
        dynamo_compile_column_us="post_grad_pass_time_us",
    ):
        for child in (getattr(gm, name) for name in _get_subgraph_names(gm)):
            _recursive_post_grad_passes(child, is_inference)
        post_grad_passes(gm, is_inference)
def split_const_gm(
    gm: GraphModule,
    skip_constructor: bool = True,
    lifted_constant_names: Optional[list[str]] = None,
    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> tuple[GraphModule, dict[str, int]]:
    """
    This function takes a GraphModule input "gm".
    The gm will be split into 2 components,
    1) const_gm, which consists of the subgraph of gm that can be constant folded.
    2) gm (modified in place), which returns the graph after constant folding.
    If an additional "lifted_constant_names" argument is passed in, we assume the
    gm has been lifted and run the transformation accordingly; in that case
    const_gm is not executed here and no folded values are substituted.
    When a "skip_folding_node_fn" callback is passed, we will skip constant folding on
    the nodes for which the callback returns True.
    const_output_index maps each new "_FOLDED_CONST_<node name>" attribute name
    to the output index of const_gm that produces it.
    Returns (const_gm, const_output_index)
    """
    from torch._inductor.constant_folding import (
        CONST_MODULE_TAG,
        META_TAG,
        MODULE_TAG,
        replace_node_with_constant,
        run_and_get_constant_graph,
    )

    const_gm = run_and_get_constant_graph(
        gm, skip_constructor, lifted_constant_names, skip_folding_node_fn
    )
    # Folded values are only computed eagerly when constants are not lifted.
    const_result = const_gm() if lifted_constant_names is None else None

    # Node name -> position in const_gm's returned sequence. This assumes
    # const_gm's last node is its output node.
    const_outputs = {
        x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0])
    }

    to_erase_node = []
    to_replace_node = []
    const_output_index = {}
    # Nodes produced by const_gm are replaced by constants; other nodes tagged
    # as constant-module (except placeholders) become candidates for erasure.
    for node in gm.graph.nodes:
        if node.name in const_outputs:
            to_replace_node.append(node)
        elif node.meta[META_TAG] == CONST_MODULE_TAG and node.op != "placeholder":
            to_erase_node.append(node)

    for node in to_replace_node:
        new_const_name = "_FOLDED_CONST_" + node.name
        # With lifted constants the value passed is None. NOTE(review):
        # presumably the value is bound later; confirm in
        # replace_node_with_constant.
        replace_node_with_constant(
            gm,
            node,
            (
                const_result[const_outputs[node.name]]  # type:ignore[index]
                if lifted_constant_names is None
                else None
            ),
            new_const_name,
        )
        const_output_index[new_const_name] = const_outputs[node.name]
    # Walk in reverse graph order. A node that still has users is kept, and
    # all of those users must be tagged MODULE_TAG; unused nodes are erased.
    for node in to_erase_node[::-1]:
        if node.users:
            for n in node.users:
                assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty."
        else:
            gm.graph.erase_node(node)
    gm.recompile()

    return const_gm, const_output_index
def is_tf32_warning_applicable(gm: GraphModule) -> bool:
aten = torch.ops.aten
tf32_ops = OrderedSet(
[
aten.mm.default,
aten.addmm.default,
aten.bmm.default,
aten.baddbmm.default,
]
)
for target in tf32_ops:
for node in gm.graph.find_nodes(op="call_function", target=target):
if (
isinstance(node.meta.get("val", None), torch.Tensor)
and node.meta["val"].dtype == torch.float32
and node.meta["val"].device.type == "cuda"
):
return True
return False
def maybe_disable_comprehensive_padding(
    example_inputs: Sequence[InputType],
) -> AbstractContextManager[None, None]:
    """
    Return ``config.patch(comprehensive_padding=False)`` when comprehensive
    padding should be skipped, else a no-op context manager.

    Padding is skipped, and a perf hint logged, in two cases:

    * ``config.disable_padding_cpu`` and ``config.comprehensive_padding`` are
      both set and no tensor in ``example_inputs`` is on a GPU. For the CPU
      backend, comprehensive padding makes some unit tests fail because the
      number of generated kernels changes.
    * ``config.aot_inductor.use_runtime_constant_folding`` is set.
    """
    on_gpu = any(
        is_gpu(inp.device.type)
        for inp in example_inputs
        if isinstance(inp, torch.Tensor)
    )
    if config.disable_padding_cpu and config.comprehensive_padding and not on_gpu:
        reason = "Skip comprehensive padding on CPU"
    elif config.aot_inductor.use_runtime_constant_folding:
        reason = "Skip comprehensive padding for use_runtime_constant_folding"
    else:
        return contextlib.nullcontext()
    perf_hint_log.info(reason)
    return config.patch(comprehensive_padding=False)
def maybe_disable_graph_partition(
    cpp_wrapper: bool, aot_mode: bool
) -> AbstractContextManager[None, None]:
    """
    Return a context manager that disables ``config.graph_partition`` when
    either ``cpp_wrapper`` or ``aot_mode`` is truthy, since graph partition
    does not support those yet. Otherwise return a no-op context manager.
    """
    if not (cpp_wrapper or aot_mode):
        return contextlib.nullcontext()
    return config.patch(graph_partition=False)
def fake_tensor_prop(
gm: GraphModule,
example_inputs: Sequence[InputType],
force_allow_non_fake_inputs: bool = False,
) -> torch._subclasses.FakeTensorMode:
"""
If we can not detect fake mode from the context of inputs, create one.
The created fake mode will be returned.
"""
# Ensure that decomps that support symbolic shapes are used
with enable_python_dispatcher():
fake_mode = detect_fake_mode(example_inputs)
if not fake_mode:
fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
else:
ctx = (
contextlib.nullcontext()
if not force_allow_non_fake_inputs
else mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
)
with ctx: # type: ignore[attr-defined]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
*example_inputs
)
return fake_mode
# Hand the user a config dict with their patches applied.
def get_patched_config_dict(
    config_patches: Optional[Union[str, dict[str, Any]]] = None,
) -> dict[str, Any]:
    """Return ``config.get_config_copy()`` taken while ``config_patches`` is applied."""
    with config.patch(config_patches):
        patched_copy = config.get_config_copy()
    return patched_copy
@contextlib.contextmanager
def with_fresh_cache_if_config() -> Generator[None, None, None]:
    """
    Run the body inside ``fresh_cache`` when ``config.force_disable_caches``
    is set; otherwise this context manager does nothing.
    """
    if not config.force_disable_caches:
        yield
        return
    # The cache dir is not deleted because it must outlive the compile_fx
    # call. Placing the temp dirs under the default cache dir makes them
    # easier to find.
    with fresh_cache(dir=cache_dir(), delete=False):
        yield
class _CompileFxKwargs(TypedDict, total=False):
    """
    Optional keyword arguments accepted by ``compile_fx_inner`` and passed on
    to the compile pipeline. ``total=False``: any key may be absent.
    """

    # When None, ``_compile_fx_inner`` fills it with
    # BoxedBool(config.triton.cudagraphs).
    cudagraphs: Optional[BoxedBool]
    # Indices of example inputs to treat as static.
    static_input_idxs: Sequence[int]
    is_backward: bool
    # Used in the "done compiling" step log message.
    graph_id: Optional[int]
    cpp_wrapper: bool
    # Not defaulted by compile_fx_inner; _compile_fx_inner reads
    # V.aot_compilation instead of this key.
    aot_mode: bool
    is_inference: bool
    layout_opt: Optional[bool]
    extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]]
    boxed_forward_device_index: Optional[BoxedDeviceIndex]
class _CompileFxCallable(Protocol):
    """
    Structural type for callables with ``compile_fx_inner``'s signature. Such
    a callable takes a GraphModule, its example inputs, and
    ``_CompileFxKwargs`` keyword arguments, and returns an OutputCode.
    """

    def __call__(
        self,
        gm: GraphModule,
        example_inputs: Sequence[InputType],
        **kwargs: Unpack[_CompileFxKwargs],
    ) -> OutputCode: ...
def compile_fx_inner(
    gm: GraphModule,
    example_inputs: Sequence[InputType],
    **kwargs: Unpack[_CompileFxKwargs],
) -> OutputCode:
    """
    Entry point for compiling one FX graph with Inductor.

    First, defaults are filled in for the optional ``_CompileFxKwargs`` keys
    (``aot_mode`` is intentionally left alone). Then
    ``_compile_fx_inner`` is called through ``wrap_compiler_debug``, inside
    these contexts, entered in this order:

    * current Python dispatch modes disabled;
    * lazy graph modules enabled per ``dynamo_config.use_lazy_graph_module``;
    * a "compile_fx_inner" dynamo timer;
    * an optional fresh cache;
    * a DebugContext.
    """
    kwargs.setdefault("cudagraphs", None)
    kwargs.setdefault("static_input_idxs", ())
    kwargs.setdefault("is_backward", False)
    kwargs.setdefault("graph_id", None)
    kwargs.setdefault("cpp_wrapper", False)
    kwargs.setdefault("is_inference", False)
    kwargs.setdefault("boxed_forward_device_index", None)
    kwargs.setdefault("layout_opt", None)
    kwargs.setdefault("extern_node_serializer", None)

    # compile_fx also enters with_fresh_cache_if_config, but it is needed here
    # as well. The backward graph may be compiled after compile_fx has
    # returned, and we may want to use the _LazyGraphModule for compiling
    # the backward graph too.
    with (
        torch.utils._python_dispatch._disable_current_modes(),
        _use_lazy_graph_module(dynamo_config.use_lazy_graph_module),
        dynamo_utils.dynamo_timed(
            "compile_fx_inner",
            phase_name="inductor_compile",
            log_pt2_compile_event=True,
            log_waitcounter=True,
            waitcounter_name_override="compile_inductor",
            dynamo_compile_column_us="inductor_cumulative_compile_time_us",
        ),
        with_fresh_cache_if_config(),
        DebugContext(),
    ):
        CompileEventLogger.pt2_compile(
            "inductor_compile",
            is_backward=kwargs["is_backward"],
        )
        debug_wrapped = wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")
        return debug_wrapped(gm, example_inputs, **kwargs)
@time_and_log(attr="compilation time (in seconds)")
def _compile_fx_inner(
    gm: GraphModule,
    example_inputs: Sequence[InputType],
    **graph_kwargs: Unpack[_CompileFxKwargs],
) -> OutputCode:
    """
    Inductor API that compiles a single graph.

    If ``gm`` contains no calls and we are not in AOT mode, inductor is not
    run. A "backward no-op" event is logged and
    ``make_boxed_func(gm.forward)`` is returned.

    Otherwise the graph is compiled with ``fx_codegen_and_compile``, using
    ``FxGraphCache`` when caching is enabled. There are four paths:

    * bundled autograd cache: always compile. The Triton bundle is attached
      to the result, but nothing is saved to FxGraphCache here.
    * cache bypass/disabled: compile without saving.
    * cache miss: compile, attach the Triton bundle, and save to the cache.
    * cache hit: reuse the loaded graph and record its cache key.

    The result then has ``post_compile`` applied and is returned.

    Exceptions raised while compiling on the bundled-cache and cache-miss
    paths are re-raised as ``InductorError``. ``ShortenTraceback`` and
    ``SkipFrame`` propagate unchanged.

    If you change the argument list for this function, make sure you
    also update the call to save_args_for_compile_fx_inner below accordingly.
    """
    aot_mode: bool = V.aot_compilation

    # Clean up Compiled Triton Kernels per inductor compile, as the future objects
    # may not be valid for use after they are run/autotuned
    torch._inductor.async_compile.CompiledTritonKernels.cache_clear()

    # Outside AOT mode, a graph with no calls needs no inductor compilation:
    # return its forward directly, boxed.
    if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode:
        # trigger the real recompilation for _LazyGraphModule before returning
        # the forward method.
        from torch._dynamo.utils import CompileEventLogLevel
        from torch.fx._lazy_graph_module import _LazyGraphModule

        _LazyGraphModule.force_recompile(gm)
        compile_id = torch._guards.CompileContext.current_compile_id()
        CompileEventLogger.log_instant_event(
            "backward no-op",
            metadata={"compile_id": compile_id},
            log_level=CompileEventLogLevel.PT2_COMPILE,
        )

        return make_boxed_func(gm.forward)

    static_input_idxs: Sequence[int] = graph_kwargs.setdefault("static_input_idxs", ())
    static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs)
    inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)

    # The graph's last node is its output, and it must return a tuple/list.
    assert isinstance(next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)), (
        f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
    )

    # An unset cudagraphs flag falls back to config.triton.cudagraphs.
    if graph_kwargs.get("cudagraphs") is None:
        graph_kwargs["cudagraphs"] = BoxedBool(config.triton.cudagraphs)
    if config.save_args:
        save_args_for_compile_fx_inner(
            gm,
            example_inputs,
            **graph_kwargs,
        )

    # Wall-clock start, used for the "FX codegen and compilation took" log.
    start = time.time()

    fx_graph_remote_cache = should_use_remote_fx_graph_cache()

    # Check if the registered backend(s) support caching.
    init_backend_registration()
    backends_support_caching = all(
        backend.supports_caching
        for backend in (
            get_wrapper_codegen_for_device(device.type, config.cpp_wrapper)
            for device in get_all_devices(gm)
        )
        if backend is not None
    )

    with dynamo_timed(
        "fx_codegen_and_compile", log_pt2_compile_event=True, log_waitcounter=True
    ):
        # FxGraphCache is consulted only when all of these hold:
        # - caches are not force-disabled;
        # - a local or remote cache is enabled;
        # - we are not in AOT mode;
        # - every backend supports caching;
        # - the bundled autograd cache (which stores results itself) is off.
        use_cache = (
            not config.force_disable_caches
            and (config.fx_graph_cache or fx_graph_remote_cache)
            and not aot_mode
            and backends_support_caching
            and not torch._functorch.config.bundled_autograd_cache
        )
        local = config.fx_graph_cache
        remote = fx_graph_remote_cache
        set_feature_use("fx_cache", use_cache)

        log.debug(
            "FX cache status: use_cache=%s, local=%s, remote=%s, aot_mode=%s, force_disable_caches=%s",
            use_cache,
            local,
            remote,
            aot_mode,
            config.force_disable_caches,
        )

        # TODO: This is a hack purely to get some info to extract_tensor_metadata_for_cache_key,
        # figure out how to not have to modify example inputs
        # Marks static GPU tensors in place via a private attribute.
        for i, input in enumerate(example_inputs):
            if (
                isinstance(input, torch.Tensor)
                and is_gpu(input.device.type)
                and i in static_input_idxs
            ):
                input._is_inductor_static = True  # type: ignore[attr-defined]

        # Results of the cache lookup; these stay None when the cache is unused.
        mb_compiled_graph: Optional[OutputCode] = None
        key_info = None
        cache_info = None
        remote_cache = None
        constants = CompiledFxGraphConstantsWithGm(gm)
        # TODO: this time will be slightly inconsistent with the one computed
        # in prepare_key/load_with_key, dump those settings of "cache_event_time"
        start_time = time.time_ns()

        if use_cache:
            (key_info, cache_info) = FxGraphCache.prepare_key(
                gm, example_inputs, graph_kwargs, inputs_to_check, remote
            )

            # Attempt a cache lookup
            if key_info is not None:
                key, debug_lines = key_info
                log.debug("FX cache key generated: %s", key)
                if remote:
                    remote_cache = FxGraphCache.get_remote_cache()
                    log.debug("Using remote FX cache")
                mb_compiled_graph, cache_info = FxGraphCache.load_with_key(
                    key,
                    debug_lines,
                    example_inputs,
                    local,
                    remote_cache,
                    is_backward=graph_kwargs.get("is_backward", False),
                    constants=constants,
                )
            else:
                log.debug("Failed to generate FX cache key")

        # BUNDLED AUTOGRAD CACHE: use_cache is False in this mode (see above),
        # so no lookup happened and the graph is always compiled.
        if torch._functorch.config.bundled_autograd_cache:
            assert mb_compiled_graph is None
            assert cache_info is None
            # When using bundled autograd cache, we still want
            # to use the TritonBundler, but we don't want to save
            # the results here. The results will get saved directly
            # to AOTAutogradCache.
            TritonBundler.begin_compile()
            try:
                mb_compiled_graph = fx_codegen_and_compile(
                    gm, example_inputs, inputs_to_check, **graph_kwargs
                )
                assert mb_compiled_graph is not None
                # NOTE(review): triton_bundler_meta is collected but unused on
                # this path.
                (
                    triton_bundle,
                    triton_bundler_meta,
                ) = TritonBundler.collect()
                mb_compiled_graph.set_triton_bundle(triton_bundle)
            except (ShortenTraceback, SkipFrame):
                raise
            except Exception as e:
                raise InductorError(e, currentframe()).with_traceback(
                    e.__traceback__
                ) from None
            finally:
                TritonBundler.end_compile()
        # CACHE BYPASS: Compile the graph, don't save it to the cache
        # (this can happen either because cache was disabled, or we
        # determined the input is uncacheable)
        elif cache_info is None or cache_info["cache_state"] == "bypass":
            assert mb_compiled_graph is None
            log.debug(
                "FX cache bypass reason: %s",
                (
                    cache_info.get("cache_bypass_reason", "unknown")
                    if cache_info is not None
                    else "FX cache disabled or key generation failed"
                ),
            )
            mb_compiled_graph = fx_codegen_and_compile(
                gm, example_inputs, inputs_to_check, **graph_kwargs
            )

        # CACHE MISS: Compile the graph and save to cache
        elif cache_info["cache_state"] == "miss":
            assert mb_compiled_graph is None
            assert key_info is not None
            log.debug("FX cache miss, compiling and saving to cache")
            TritonBundler.begin_compile()
            try:
                mb_compiled_graph = fx_codegen_and_compile(
                    gm, example_inputs, inputs_to_check, **graph_kwargs
                )
                assert mb_compiled_graph is not None
                # Record compile time and the cache key on the graph before saving.
                mb_compiled_graph._time_taken_ns = time.time_ns() - start_time
                cache_key, debug_lines = key_info
                mb_compiled_graph._fx_graph_cache_key = cache_key
                mb_compiled_graph._fx_graph_cache_debug_lines = debug_lines
                (
                    triton_bundle,
                    triton_bundler_meta,
                ) = TritonBundler.collect()
                mb_compiled_graph.set_triton_bundle(triton_bundle)
            except (ShortenTraceback, SkipFrame):
                raise
            except Exception as e:
                raise InductorError(e, currentframe()).with_traceback(
                    e.__traceback__
                ) from None
            finally:
                TritonBundler.end_compile()
            if triton_bundler_meta is not None:
                cache_info["triton_bundler_meta"] = str(triton_bundler_meta)
            cache_info["time_taken_ns"] = mb_compiled_graph._time_taken_ns
            log.debug("Saving compiled graph to FX cache with key: %s", cache_key)
            FxGraphCache._save_graph(
                cache_key,
                mb_compiled_graph,
                example_inputs,
                local,
                remote_cache,
            )
        # CACHE HIT: not much to really do, just make sure the cache key
        # is recorded on the graph
        else:
            assert cache_info["cache_state"] == "hit"
            assert mb_compiled_graph is not None
            assert key_info is not None
            (cache_key, debug_lines) = key_info
            log.debug("FX cache hit with key: %s", cache_key)
            mb_compiled_graph._fx_graph_cache_key = cache_key
            mb_compiled_graph._fx_graph_cache_debug_lines = debug_lines

        assert mb_compiled_graph is not None
        compiled_graph = mb_compiled_graph

        # Logging and observability: we log a single chromium event
        # and a tlparse log for every cache action.
        # In the event of a bypass, we also logged to the remote table earlier
        # with log_cache_bypass.
        cache_state = (
            cache_info["cache_state"] if cache_info is not None else "disabled"
        )
        # Here for grepping:
        # fx_graph_cache_hit
        # fx_graph_cache_miss
        # fx_graph_cache_bypass
        # fx_graph_cache_disabled
        CompileEventLogger.instant(
            f"fx_graph_cache_{cache_state}",
            metadata=cache_info or {},
            time_ns=start_time,
        )
        # Add event data about cache hits/miss
        # TODO: add remote cache get/put timings here too
        CompileEventLogger.pt2_compile(
            "inductor_compile",
            cache_state=cache_state,
            cache_event_time=start_time,
            key=cache_info.get("key") if cache_info else None,
            components=cache_info.get("components") if cache_info else None,
            cache_bypass_reason=(
                cache_info.get("cache_bypass_reason")
                if cache_info
                else "cache not enabled"
            ),
            remote_cache_enabled=remote,
            local_cache_enabled=local,
        )

        # Don't clog up the main tlparse output with disabled cache
        if cache_info is not None:
            trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": f"fx_graph_cache_{cache_state}",
                    "encoding": "json",
                },
                payload_fn=lambda: json.dumps(cache_info),
            )
        compiled_graph.post_compile(example_inputs, constants, graph_kwargs)

    log.debug("FX codegen and compilation took %.3fs", time.time() - start)

    if config.trace.provenance_tracking:
        # Dump provenance artifacts for debugging trace
        provenance_info = torch._inductor.debug.dump_inductor_provenance_info()
        # provenance_info might be None if trace.provenance_tracking is not set
        if provenance_info:
            trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "inductor_provenance_tracking_node_mappings",
                    "encoding": "json",
                },
                payload_fn=lambda: json.dumps(provenance_info),
            )

    # This message is for printing overview information of inductor mm counts, shapes,etc after lowering
    # Keys in counters["aten_mm_info"] are parsed as "<op>_<m>_<n>_<k>", or
    # "<op>_<batch>_<m>_<n>_<k>" when <op> ends in bmm/baddbmm. This format is
    # inferred from the parsing below.
    if log.isEnabledFor(logging.INFO):
        mm_table_data = []
        for key, value in counters["aten_mm_info"].items():
            parts = key.split("_")
            if len(parts) < 3:
                # Unexpected format, show as-is
                mm_table_data.append([key, "-", "?", "?", "?", value])
                continue

            # Determine if this is a batched operation by checking the operation name
            name = "_".join(parts[:-4]) if len(parts) >= 4 else "_".join(parts[:-3])
            is_batched = name.endswith(("bmm", "baddbmm"))

            if is_batched and len(parts) >= 4:
                # Batched operation: last 4 parts are batch, m, n, k
                batch, m, n, k = parts[-4:]
                name = "_".join(parts[:-4])
                mm_table_data.append([name, batch, m, n, k, value])
            else:
                # Non-batched operation: last 3 parts are m, n, k
                m, n, k = parts[-3:]
                name = "_".join(parts[:-3])
                mm_table_data.append([name, "-", m, n, k, value])

        log.info("Overview info of inductor aten mms: ")
        log.info(
            "{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(  # noqa: G001
                "Name", "B", "M", "N", "K", "Count"
            )
        )
        log.info("-" * 130)
        for row in mm_table_data:
            log.info("{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row))  # noqa: G001
            log.info("-" * 130)

    # Not strictly necessary, but good to clean up straggling futures
    # that are unused to reclaim memory.
    torch._inductor.async_compile.CompiledTritonKernels.cache_clear()

    _step_logger()(
        logging.INFO,
        "torchinductor done compiling "
        f"{'BACKWARDS' if graph_kwargs['is_backward'] else 'FORWARDS'} "
        f"graph {graph_kwargs['graph_id']}",
    )
    return compiled_graph
class _FxCompileStat:
# Count of successful compiles of this type
codegen_and_compile: int = 0
def __repr__(self) -> str:
return f"codegen_and_compile: {self.codegen_and_compile}"
class FxCompile(ABC):
    """
    Abstract strategy that turns a GraphModule into an OutputCode.

    Concrete subclasses decide where and how the work is done (for example,
    in the current process); callers interact only through
    ``codegen_and_compile``.
    """

    # Statistics for logging/debugging. This is a single dict shared by the
    # whole hierarchy, keyed by the concrete subclass type.
    _compile_stats: dict[type[FxCompile], _FxCompileStat] = defaultdict(_FxCompileStat)

    @classmethod
    def _reset_stats(cls) -> None:
        cls._compile_stats.clear()

    # TODO: We should probably eventually add some kind of async version of this
    # so we can kick off a compile and then go do other things - but we'll need
    # to know what kind of API we want for that first.
    @abstractmethod
    def codegen_and_compile(
        self,
        gm: GraphModule,
        example_inputs: Sequence[InputType],
        inputs_to_check: Sequence[int],
        graph_kwargs: _CompileFxKwargs,
    ) -> OutputCode:
        """Generate code for ``gm`` and compile it into an OutputCode."""
class _InProcessFxCompile(FxCompile):
    """
    FxCompile implementation that lowers, generates code for, and compiles
    the graph directly in the current process.
    """
    @override
    def codegen_and_compile(
        self,
        gm: GraphModule,
        example_inputs: Sequence[InputType],
        inputs_to_check: Sequence[int],
        graph_kwargs: _CompileFxKwargs,
    ) -> OutputCode:
        """
        Generates the OutputCode from the GraphModule and example_inputs.

        The pipeline runs in this order:
        - convert views to reshapes and run FakeTensorProp on ``gm``;
        - apply the post-grad passes;
        - optionally lower a separate constant graph (AOT mode with runtime
          constant folding);
        - lower ``gm`` with GraphLowering;
        - compile the generated wrapper/kernel code.
        ``gm`` is transformed in place along the way.

        Returns:
            A CompiledAOTI wrapping the compiled artifact path(s) when
            ``V.aot_compilation`` is set, otherwise a CompiledFxGraph. Only
            the CompiledFxGraph path increments
            ``self._compile_stats[type(self)].codegen_and_compile``.
        """
        # Sorry about the mess, we need graph_kwargs to continue to be able
        # to propagate it further on
        # TODO: _CompileFxKwargs actually has stronger types than in the
        # signature, need to tighten it up
        # cudagraphs is a mutable box: later checks in this method may record a
        # reason to disable cudagraphs instead of flipping a plain bool.
        assert "cudagraphs" in graph_kwargs and graph_kwargs["cudagraphs"] is not None
        cudagraphs: BoxedBool = graph_kwargs["cudagraphs"]
        static_input_idxs: Sequence[int] = graph_kwargs.get("static_input_idxs", ())
        is_backward: bool = graph_kwargs.get("is_backward", False)
        graph_id: Optional[int] = graph_kwargs.get("graph_id", None)
        cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False)
        aot_mode: bool = V.aot_compilation
        is_inference: bool = graph_kwargs.get("is_inference", False)
        extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = (
            graph_kwargs.get("extern_node_serializer", None)
        )
        with (
            _WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard(),
            dynamo_utils.preserve_rng_state(),
        ):
            if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None:
                # NOTE(review): `time` is already imported at module level;
                # this local import is redundant.
                import time
                log.warning(
                    "Sleeping for %s since sleep_sec_TESTING_ONLY is set", sleep_sec
                )
                time.sleep(sleep_sec)
            if is_tf32_warning_applicable(gm):
                _warn_tf32_disabled()
            # Snapshot so the per-compile delta of the inductor counters can be
            # stored on the CompiledFxGraph at the end of this method.
            inductor_counters = counters["inductor"].copy()
            # lift the maximum depth of the Python interpreter stack
            # to adapt large/deep models
            sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))
            _step_logger()(
                logging.INFO,
                "torchinductor compiling "
                f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
                f"graph {graph_id}",
            )
            # Render a runnable repro of the input graph. It is emitted as a
            # structured-log artifact and also kept on the CompiledFxGraph.
            fd = io.StringIO()
            torch._dynamo.repro.after_aot.save_graph_repro(
                fd, gm, example_inputs, "inductor", save_dir=None
            )
            runnable_graph_str = fd.getvalue()
            trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "fx_graph_runnable",
                    "encoding": "string",
                },
                payload_fn=lambda: runnable_graph_str,
            )
            V.debug.fx_graph(gm, example_inputs)
            # TODO: Should we actually dump this?  It should be redundant with the aot
            # structured logs...
            # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False))
            shape_env = shape_env_from_inputs(example_inputs)
            # Convert view to reshape in the graph. This is necessary primarily for
            # layout optimization. Do it unconditionally for uniformity.
            #
            # It's needed because when we do layout optimization, a contiguous tensor
            # in eager mode may become a channels-last tensor. A view op that could
            # previously be applied to the contiguous tensor may no longer be
            # applicable to the channels-last tensor. An error like
            #   RuntimeError: view size is not compatible with input tensor's size and stride
            #   (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
            # will be printed.
            #
            # Replace view op to reshape op in this case.
            # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this.
            #
            # Also this has to be done before FakeTensorProp below to avoid the failed
            # .view() call.
            view_to_reshape(gm)
            with dynamo_timed(
                "additional_fake_tensor_prop", log_pt2_compile_event=True
            ):
                # It is safe to run FakeTensorProp under no_grad because by the time
                # we're in inductor, we assume that AOTAutograd has already "taken care"
                # of autograd, so there should be no more autograd-related API's in the
                # graph.
                with torch.no_grad():
                    fake_mode = fake_tensor_prop(gm, example_inputs)
            _recursive_record_original_output_strides(gm)
            # pattern matcher passes might not preserve striding information
            # on node.meta["val"]. if in the future we rely on these being
            # correct we will need to fix.
            trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "before_post_grad_graph",
                    "encoding": "string",
                },
                payload_fn=lambda: gm.print_readable(
                    print_output=False, include_stride=True, include_device=True
                ),
            )
            with V.set_fake_mode(fake_mode):
                # has some issues with memory in training
                cuda_context = get_cuda_device_context(gm)
                with cuda_context:
                    _recursive_post_grad_passes(gm, is_inference=is_inference)
                V.debug.fx_graph_transformed(gm, example_inputs)
                post_grad_graphs_log.debug(
                    "%s",
                    lazy_format_graph_code(
                        "AFTER POST GRAD",
                        gm,
                        include_stride=True,
                        include_device=True,
                        colored=True,
                    ),
                )
                # We're printing the graph to be used as a cache key - so a
                # printer which is a little less readable but faster is
                # appropriate.
                inductor_post_grad_graph_str = gm.print_readable(
                    print_output=False,
                    include_stride=True,
                    include_device=True,
                    fast_sympy_print=True,
                )
                # "after_post_grad_graph" is used in inductor provenance
                # tracking highlighter front-end.
                trace_structured(
                    "artifact",
                    metadata_fn=lambda: {
                        "name": "after_post_grad_graph",
                        "encoding": "string",
                    },
                    payload_fn=lambda: inductor_post_grad_graph_str,
                )
                if config.trace.provenance_tracking:
                    provenance_tracking_json = (
                        torch.fx.traceback.get_graph_provenance_json(gm.graph)
                    )
                    trace_structured(
                        "artifact",
                        metadata_fn=lambda: {
                            "name": "inductor_post_to_pre_grad_nodes",
                            "encoding": "json",
                        },
                        payload_fn=lambda: json.dumps(provenance_tracking_json),
                    )
                    from torch._inductor.debug import create_mapping_pre_post_grad_nodes
                    torch._inductor.debug._inductor_post_to_pre_grad_nodes = (
                        create_mapping_pre_post_grad_nodes(
                            torch._inductor.debug._pre_grad_graph_id,
                            provenance_tracking_json,
                        )
                    )
                metrics_context = get_metrics_context()
                if metrics_context.in_progress():
                    # TODO: Remove this when 3.9 is no longer supported
                    if sys.version_info < (3, 10):
                        num_graph_breaks = sum(counters["graph_break"].values())
                    else:
                        num_graph_breaks = counters["graph_break"].total()
                    CompileEventLogger.compilation_metric(
                        overwrite=True, num_graph_breaks=num_graph_breaks
                    )
                if config.is_fbcode():
                    try:
                        log_optimus_to_scuba(
                            extra_logging={
                                "pt2_configs": str(get_patched_config_dict())
                            }
                        )
                    except Exception:
                        # TODO(T216453900): need to work around for now to support vllm
                        # See details in vllm/compilation/pass_manager.py.
                        log.warning("failed to log pt2_configs")
            with (
                V.set_fake_mode(fake_mode),
                maybe_disable_comprehensive_padding(example_inputs),
                maybe_disable_graph_partition(cpp_wrapper, aot_mode),
            ):
                # These stay None unless AOT runtime constant folding splits
                # the constants out into their own GraphLowering below.
                const_output_index = None
                const_graph = None
                const_wrapper_code = None
                const_kernel_code = None
                if aot_mode and config.aot_inductor.use_runtime_constant_folding:
                    # torchbind objects have name that starts with _torchbind_obj
                    # See caffe2/torch/fx/_symbolic_trace.py?lines=406
                    const_gm, const_output_index = split_const_gm(
                        gm,
                        skip_folding_node_fn=lambda node: node.op == "get_attr"
                        and isinstance(node.target, str)
                        and (
                            node.target.startswith("_torchbind_obj")
                            or isinstance(node.meta.get("val", None), FakeScriptObject)
                        ),
                    )
                    const_graph = GraphLowering(
                        const_gm,
                        example_inputs=[],
                        shape_env=shape_env,
                        graph_id=graph_id,
                        cpp_wrapper=cpp_wrapper,
                        aot_mode=aot_mode,
                        extern_node_serializer=extern_node_serializer,
                        is_inference=is_inference,
                        is_backward=is_backward,
                        is_const_graph=True,
                    )
                    with V.set_graph_handler(const_graph):
                        assert cpp_wrapper, "AOT mode only supports C++ wrapper"
                        const_graph.run()
                        const_wrapper_code, const_kernel_code = (
                            const_graph.codegen_with_cpp_wrapper()
                        )
                graph = GraphLowering(
                    gm,
                    # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
                    # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
                    # we currently use fake tensors and defake them later.
                    example_inputs=example_inputs,
                    shape_env=shape_env,
                    graph_id=graph_id,
                    cpp_wrapper=cpp_wrapper,
                    aot_mode=aot_mode,
                    extern_node_serializer=extern_node_serializer,
                    is_inference=is_inference,
                    is_backward=is_backward,
                    const_output_index=const_output_index,
                    const_wrapper_code=(
                        const_wrapper_code.value if const_wrapper_code else None
                    ),
                    const_kernel_code=(
                        const_kernel_code.value if const_kernel_code else None
                    ),
                    const_module=const_graph,
                    inputs_to_check=inputs_to_check,
                )
                # Baseline for the metrics deltas stored on the CompiledFxGraph
                # (metrics_helper.get_deltas() at the end of this method).
                metrics_helper = metrics.CachedMetricsHelper()
                # We are going to start code generating runtime asserts, so make sure
                # you don't start adding new ones in the lowering process
                graph.freeze_runtime_asserts()
                with V.set_graph_handler(graph):
                    graph.run(*example_inputs)
                    # One entry per graph output: a tuple of stride expressions
                    # rendered as strings, or None when the output is not a
                    # tensor IRNode or its stride contains unbacked symbols.
                    output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
                    if graph.graph_outputs is not None:
                        # We'll put the output strides in the compiled graph so we
                        # can later return them to the caller via TracingContext
                        p = SymExprPrinter()
                        for out in graph.graph_outputs:
                            if (
                                isinstance(out, IRNode)
                                and out.has_tensor_output()
                                and len(free_unbacked_symbols(out.get_stride())) == 0
                            ):
                                # Convert to string for eval on the load path
                                output_strides.append(
                                    tuple(p.doprint(s) for s in out.get_layout().stride)
                                )
                            else:
                                output_strides.append(None)
                    _check_triton_bf16_support(graph)
                    # TODO: The switching between AOT mode and not here is a bit
                    # messy, but it's localized to the block of code below so I'm
                    # not going to touch it for now
                    # In AOT mode compiled_fn holds the compiled artifact path(s)
                    # (asserted to be str or list below); otherwise it is the
                    # generated module's `call` function.
                    compiled_fn: Any
                    compiled_fn_runner = None
                    with dynamo_timed(
                        "GraphLowering.compile_to_fn", log_pt2_compile_event=True
                    ):
                        if graph.aot_mode:
                            from .codecache import AotCodeCompiler
                            assert graph.cpp_wrapper, (
                                "AOT mode only supports C++ wrapper"
                            )
                            wrapper_code, kernel_code = graph.codegen_with_cpp_wrapper()
                            output_code_log.debug(
                                "Output wrapper code: \n%s", wrapper_code.value
                            )
                            if kernel_code.value:
                                output_code_log.debug(
                                    "Output kernel code:\n%s", kernel_code.value
                                )
                            serialized_extern_kernel_nodes = None
                            if graph.extern_kernel_nodes:
                                serialized_extern_kernel_nodes = (
                                    graph.extern_node_serializer(
                                        graph.extern_kernel_nodes
                                    )
                                )
                                output_code_log.debug(
                                    "Serialized Extern Kernel Nodes: \n%s",
                                    serialized_extern_kernel_nodes,
                                )
                            with dynamo_timed(
                                "AotCodeCompiler.compile", log_pt2_compile_event=True
                            ):
                                # Directly return the file path with the compiled code
                                # additional_files merges the main graph's and the
                                # constant graph's files; dict.fromkeys removes
                                # duplicates while keeping first-seen order.
                                compiled_fn = AotCodeCompiler.compile(
                                    graph,
                                    wrapper_code.value,
                                    kernel_code.value,
                                    serialized_extern_kernel_nodes,
                                    device_type=graph.device_type,
                                    additional_files=[
                                        *dict.fromkeys(
                                            graph.wrapper_code.additional_files
                                            + (
                                                const_graph.wrapper_code.additional_files
                                                if const_graph
                                                else []
                                            )
                                        )
                                    ],
                                )
                        else:
                            compiled_module = graph.compile_to_module()
                            compiled_fn = compiled_module.call
                            compiled_fn_runner = getattr(
                                compiled_module, "runner", None
                            )
                    if inductor_metrics_log.isEnabledFor(logging.INFO):
                        num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
                        metrics.num_bytes_accessed += num_bytes
                        metrics.node_runtimes += node_runtimes
                        metrics.nodes_num_elem += nodes_num_elem
                        inductor_metrics_log.info(
                            "Graph Metrics:\n%s",
                            {
                                "num_bytes_accessed": num_bytes,
                                "nodes_num_elem": nodes_num_elem,
                                "node_runtimes": node_runtimes,
                            },
                        )
                    # With cudagraph_skip_dynamic_graphs, record why cudagraphs
                    # are disabled for graphs with symbolic inputs. When possible,
                    # point at the stack trace of the first non-placeholder node
                    # whose tensor value is symbolic.
                    if (
                        cudagraphs
                        and config.triton.cudagraph_skip_dynamic_graphs
                        and not V.graph.disable_cudagraphs_reason
                        and torch._inductor.utils.any_is_symbolic(*example_inputs)
                    ):
                        stack_trace = None
                        for node in gm.graph.nodes:
                            meta_val = node.meta.get("val", None)
                            if (
                                node.op == "placeholder"
                                or not isinstance(meta_val, torch.Tensor)
                                or not torch._inductor.utils.any_is_symbolic(meta_val)
                            ):
                                continue
                            if stack_trace := node.meta.get("stack_trace", None):
                                break
                        disable = "graph with symbolic shapes inputs and config.triton.cudagraph_skip_dynamic_graphs=True."
                        if stack_trace:
                            disable = f"{disable} Found from {stack_trace}\n"
                        else:
                            disable = f"{disable}\n"
                        V.graph.disable_cudagraphs_reason = disable
                    # Otherwise, disable cudagraphs if the graph contains an op
                    # that get_first_incompatible_cudagraph_node reports.
                    if cudagraphs and not V.graph.disable_cudagraphs_reason:
                        maybe_incompat_node = get_first_incompatible_cudagraph_node(gm)
                        if maybe_incompat_node:
                            disable = f"disabling cudagraphs due to incompatible op {maybe_incompat_node.target}"
                            if stack_trace := maybe_incompat_node.meta.get(
                                "stack_trace", None
                            ):
                                disable = f"{disable} Found from {stack_trace}\n"
                            V.graph.disable_cudagraphs_reason = disable
                    # AOT path returns early: no CompiledFxGraph and no
                    # _compile_stats increment.
                    if V.aot_compilation:
                        assert isinstance(compiled_fn, (str, list))
                        return CompiledAOTI(compiled_fn)
                    # TODO: Hoist this above V.aot_compilation
                    if cudagraphs and not V.graph.disable_cudagraphs_reason:
                        from torch._inductor.cudagraph_utils import (
                            check_lowering_disable_cudagraph,
                        )
                        V.graph.disable_cudagraphs_reason = (
                            check_lowering_disable_cudagraph(
                                V.graph.device_node_mapping
                            )
                        )
                    self._compile_stats[type(self)].codegen_and_compile += 1
                    return CompiledFxGraph(
                        compiled_fn,
                        graph,
                        gm,
                        output_strides,
                        V.graph.disable_cudagraphs_reason,
                        metrics_helper.get_deltas(),
                        counters["inductor"] - inductor_counters,
                        cudagraphs,
                        example_inputs,
                        static_input_idxs,
                        graph_kwargs,
                        inputs_to_check,
                        runnable_graph_str,
                        inductor_post_grad_graph_str,
                        compiled_fn_runner,
                    )
def fx_codegen_and_compile(
    gm: GraphModule,
    example_inputs: Sequence[InputType],
    # This is derivable from the other inputs to this function, but we pass it
    # in explicitly because it's nontrivial to compute
    inputs_to_check: Sequence[int],
    **graph_kwargs: Unpack[_CompileFxKwargs],
) -> OutputCode:
    """
    Select the FxCompile scheme from the module-level compile-mode settings
    and use it to generate code for and compile ``gm``.

    Scheme selection:
    - ``fx_compile_mode`` picks the base scheme: NORMAL uses
      _InProcessFxCompile, SERIALIZE uses _DebugSerdeFxCompile, and
      SUBPROCESS uses _SubprocessFxCompile.
    - ``fx_compile_async`` optionally wraps an out-of-process scheme.
    - ``fx_compile_progressive`` optionally wraps an out-of-process scheme,
      using an in-process compile for the fast version.

    Raises:
        ValueError: if ``fx_compile_mode`` is not one of the handled modes.
        AssertionError: if async/progressive compilation is requested
            without an out-of-process scheme.
    """
    scheme: FxCompile
    if fx_compile_mode == FxCompileMode.NORMAL:
        scheme = _InProcessFxCompile()
    elif fx_compile_mode == FxCompileMode.SERIALIZE:
        from .compile_fx_ext import _DebugSerdeFxCompile

        scheme = _DebugSerdeFxCompile()
    elif fx_compile_mode == FxCompileMode.SUBPROCESS:
        from .compile_fx_subproc import _SubprocessFxCompile

        scheme = _SubprocessFxCompile()
    else:
        # Fail fast with an actionable message instead of hitting an
        # UnboundLocalError on `scheme` further down.
        raise ValueError(f"Unsupported fx_compile_mode: {fx_compile_mode!r}")

    if fx_compile_async:
        from .compile_fx_async import _AsyncFxCompile
        from .compile_fx_ext import _OutOfProcessFxCompile

        assert isinstance(scheme, _OutOfProcessFxCompile), (
            "async is only valid with an out-of-process compile mode"
        )
        scheme = _AsyncFxCompile(scheme)

    if fx_compile_progressive:
        from .compile_fx_async import _ProgressiveFxCompile
        from .compile_fx_ext import _OutOfProcessFxCompile

        assert isinstance(scheme, _OutOfProcessFxCompile), (
            "progressive is only valid with an out-of-process compile mode"
        )
        progression_configs = _get_progression_configs()
        # Use in-process compile for the fast version
        fast_scheme = _InProcessFxCompile()
        scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs)

    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
def get_input_idxs_to_check(
    inputs: Sequence[InputType],
    static_input_idxs: Sequence[int],
) -> Sequence[int]:
    """
    This function runs at compile time, and generates a list of indices for which we
    might need to do a copy to preserve alignment requirements.

    An index is included when ``inputs[i]`` is a GPU tensor that the generated
    code assumes to be aligned (``should_assume_input_aligned``) and it is not
    a static input that is already known to be aligned.

    Args:
        inputs: graph inputs. Non-tensors and non-GPU tensors are skipped.
        static_input_idxs: indices of inputs treated as static.

    Returns:
        Input indices, in ascending order, that must be checked (and cloned
        if misaligned) at runtime.
    """
    # Build the set once. Membership is tested for every GPU tensor input, and
    # scanning a long sequence each time would be O(len(inputs) * len(static)).
    static_idx_set = frozenset(static_input_idxs)
    ids_to_check: list[int] = []
    for i, inp in enumerate(inputs):
        if not isinstance(inp, torch.Tensor):
            # non-tensors don't need alignment
            continue
        if not is_gpu(inp.device.type):
            # right now we only care for gpu tensors
            continue
        with maybe_get_suppress_shape_guards_ctx():
            # suppress guards so that tensor_is_aligned and should_assume_input_aligned
            # do not add guards on input's storage offset
            if i in static_idx_set and tensor_is_aligned(inp):
                continue
            if not should_assume_input_aligned(inp):
                continue
        # if we get here, then
        # (a) our triton code assumes that the input is aligned
        # (b) we can't be sure ahead of time that the input will actually be aligned.
        # therefore, at runtime, we'll need to check that the input is aligned
        # (and if not, clone it to make it aligned.)
        ids_to_check.append(i)
    return ids_to_check
def cudagraphify(
    model: Callable[..., Any],
    static_input_idxs: Sequence[int] = (),
    *,
    device_index: int,
    stack_traces: list[Optional[str]],
    is_backward: bool,
    is_inference: bool,
    constants: tuple[torch.Tensor, ...] = (),
    placeholders: Sequence[PlaceholderInfo] = (),
    mutated_input_idxs: tuple[int, ...] = (),
) -> Callable[..., Any]:
    """
    Wrap ``model`` so that it is CUDA-graphed lazily on its first call.

    Implementation choice:
    - With ``config.triton.cudagraph_trees`` enabled, the cudagraph-trees
      implementation is used. It receives all keyword-only metadata plus
      the current compile id.
    - Otherwise the module-level ``cudagraphify_impl`` is used, and that
      metadata is not forwarded.

    The returned callable takes one sequence of inputs. On the first call it
    builds the graphed function from that call's inputs, with RNG state
    preserved. Every call, including the first, then forwards its inputs to
    that function.
    """
    from torch._inductor.cudagraph_trees import (
        cudagraphify_impl as new_cudagraphify_impl,
    )

    make_graphed: Callable[..., Any]
    if not config.triton.cudagraph_trees:
        make_graphed = cudagraphify_impl
    else:
        make_graphed = functools.partial(
            new_cudagraphify_impl,
            device_index=device_index,
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=constants,
            placeholders=placeholders,
            mutated_input_idxs=mutated_input_idxs,
            compile_id=torch._guards.CompileContext.current_compile_id(),
        )

    graphed_fn = None

    def run(new_inputs: Sequence[InputType]) -> Any:
        nonlocal graphed_fn
        if graphed_fn is None:
            with dynamo_utils.preserve_rng_state():
                graphed_fn = make_graphed(model, new_inputs, static_input_idxs)  # type: ignore[arg-type]
        return graphed_fn(new_inputs)  # type: ignore[arg-type]

    return run
def static_input(x: torch.Tensor) -> torch.Tensor:
    """
    Allocate a new, uninitialized tensor with the same size, strides, dtype
    and device as ``x``.

    Despite the name, this does NOT copy ``x``'s data. Callers must fill the
    result themselves (cudagraphify_impl does this with ``copy_`` right after
    allocating). Keeping ``x``'s strides gives the buffer the same memory
    layout as the tensors that are later copied into it.
    """
    return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
def index_expanded_dims_and_copy_(
    dst: torch.Tensor,
    src: torch.Tensor,
    expanded_dims: list[int],
) -> None:
    """
    Apply ``index_expanded_dims(..., expanded_dims)`` to both ``dst`` and
    ``src``, then copy the indexed ``src`` into the indexed ``dst`` in place.

    NOTE(review): presumably ``expanded_dims`` are the broadcast dims reported
    by ``get_expanded_dims``, indexed away so the copy does not write through
    overlapping memory. Confirm against ``index_expanded_dims``.
    """
    dst_view, src_view = (index_expanded_dims(t, expanded_dims) for t in (dst, src))
    dst_view.copy_(src_view)
def cudagraphify_impl(
    model: Callable[..., Any],
    inputs: list[torch.Tensor],
    static_input_idxs: Sequence[int] = (),
) -> Callable[[list[InputType]], Any]:
    """
    Capture ``model`` into a single CUDA graph and return a function that
    replays it. This is the non-tree path used by ``cudagraphify`` when
    ``config.triton.cudagraph_trees`` is off.

    Assumes inputs[static_input_idxs[i]] are always the same memory address

    Steps:
    - Handle misaligned inputs.
    - Allocate static buffers for the non-static tensor inputs and copy the
      current values into them.
    - Run ``model`` once on a side stream as warmup, then capture it.

    The returned ``run`` takes the boxed input list. It copies the non-static
    tensors into the captured buffers, clears the list, replays the graph and
    returns the same ``static_outputs`` objects on every call.
    """
    # Input indices whose alignment must be re-checked at runtime.
    check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)  # type: ignore[arg-type]
    # Set of static indices, minus those dropped by remove_unaligned_input_idxs
    # (presumably static inputs that aren't aligned). Dropped indices are
    # treated like any other non-static input below.
    static_input_idxs: OrderedSet[int] = OrderedSet(
        remove_unaligned_input_idxs(inputs, static_input_idxs)  # type: ignore[arg-type]
    )
    # NOTE(review): presumably replaces misaligned entries of `inputs` in place
    # with aligned clones - confirm against copy_misaligned_inputs.
    copy_misaligned_inputs(inputs, check_input_idxs)  # type: ignore[arg-type]
    assert isinstance(inputs, list)
    # Static inputs are used as-is, so only non-static inputs need their
    # expanded dims (used when copying into the captured buffers).
    inps_expanded_dims = [
        get_expanded_dims(x) if idx not in static_input_idxs else []
        for idx, x in enumerate(inputs)
    ]
    # allocate static tensor inputs
    # Non-tensors pass through unchanged. Non-static tensors get a fresh buffer
    # with the same strides. Static tensors are reused through detach().
    static_inputs = [
        (
            x
            if not isinstance(x, torch.Tensor)
            else static_input(x)
            if idx not in static_input_idxs
            else x.detach()
        )
        for idx, x in enumerate(inputs)
    ]
    # copy over input values for fresh allocations
    for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)):
        if isinstance(x, torch.Tensor) and idx not in static_input_idxs:
            index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims)
    # warmup
    # Run the model once on a side stream (ordered after the current stream)
    # before capturing it.
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    # copy static_inputs because it will be cleared in model
    with torch.cuda.stream(stream):
        model(list(static_inputs))
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()
    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
        static_outputs = model(list(static_inputs))
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)
    if config.size_asserts:
        # Checked variant: also asserts that static inputs still point at the
        # captured storage.
        def run(new_inputs: list[InputType]) -> Sequence[Any]:
            assert len(static_inputs) == len(new_inputs)
            for idx, (dst, src, expanded_dims) in enumerate(
                zip(static_inputs, new_inputs, inps_expanded_dims)
            ):
                if not isinstance(dst, torch.Tensor):
                    continue
                assert isinstance(src, torch.Tensor)
                if idx in static_input_idxs:
                    assert dst.data_ptr() == src.data_ptr()
                else:
                    # TODO - could make one single op of multiple slices
                    # and avoid dispatch.
                    # Could also pre-index the `dst` tensors
                    index_expanded_dims_and_copy_(dst, src, expanded_dims)
            # Boxed convention: drop the caller's references to the inputs.
            new_inputs.clear()
            graph.replay()
            return static_outputs
    else:
        # Fast variant: precompute which indices need copying.
        copy_indices = [
            idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
        ]
        def run(new_inputs: list[InputType]) -> Sequence[Any]:
            for idx in copy_indices:
                expanded_dims = inps_expanded_dims[idx]
                src = new_inputs[idx]
                assert isinstance(src, torch.Tensor)
                index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims)
            # Boxed convention: drop the caller's references to the inputs.
            new_inputs.clear()
            graph.replay()
            return static_outputs
    return align_inputs_from_check_idxs(run, check_input_idxs, OrderedSet())
def compile_fx_aot(
    model_: GraphModule,
    example_inputs_: list[InputType],
    inner_compile: _CompileFxCallable = compile_fx_inner,
    config_patches: Optional[dict[str, Any]] = None,
) -> Union[list[Union[str, Weights]], str]:
    """
    Ahead-of-time compile ``model_`` through ``compile_fx`` and return the
    resulting artifact (``CompiledAOTI.filename``).

    Config handling:
    - ``cpp_wrapper`` is always forced on.
    - When no ``aot_inductor.output_path`` is set (neither in
      ``config_patches`` nor in the global config), a path is derived from
      ``code_hash(model_.code)``.
    - An ``extern_node_serializer`` entry in ``config_patches`` is removed
      from the patches and forwarded to ``inner_compile`` instead.

    The caller's ``config_patches`` dict is never mutated.
    """
    assert isinstance(model_, GraphModule), model_
    # [See NOTE] Unwrapping subclasses AOT
    unwrap_tensor_subclass_parameters(model_)

    # Fresh dict with the C++ wrapper forced on.
    patches: dict[str, Any] = {**(config_patches or {}), "cpp_wrapper": True}

    output_path_key = "aot_inductor.output_path"
    output_path = patches.get(output_path_key, config.aot_inductor.output_path)
    if not output_path:
        # No explicit destination: name it after a hash of the model's code.
        patches = {**patches, output_path_key: code_hash(model_.code)}
    else:
        assert not output_path.endswith(".pt2"), (
            "The output path for aot_compile should not have an extension with .pt2 "
            "this is for specifying the output path for the .so in AOTInductor. "
            "If you would like to package the AOTInductor generated files "
            "into a pt2, please call `torch._inductor.aoti_compile_and_package`."
        )

    from .utils import maybe_aoti_standalone_config

    patches = maybe_aoti_standalone_config(patches)
    # Not applied as a config patch; forwarded to inner_compile below.
    extern_node_serializer = patches.pop("extern_node_serializer", None)

    compile_context = torch._guards.CompileContext(
        model_.meta.get("dynamo_compile_id", None)
    )
    with (
        V.set_aot_compilation(True),
        torch._guards.compile_context(compile_context),
        chromium_event_timed(
            "compile_fx_aot",
            log_pt2_compile_event=True,
            reset_event_log_on_exit=True,
        ),
        get_metrics_context(),
    ):
        artifacts = compile_fx(
            model_,
            example_inputs_,
            inner_compile=functools.partial(
                inner_compile,
                extern_node_serializer=extern_node_serializer,
            ),
            config_patches=patches,
        )
        assert isinstance(artifacts, CompiledAOTI)
    return artifacts.filename
_graph_counter = count(0)
def fw_compiler_freezing(
    aot_autograd_model: GraphModule,
    aot_example_inputs: Sequence[InputType],
    dynamo_model: GraphModule,
    num_example_inputs: int,
    inner_compile: Callable[..., Any],
    cudagraphs: BoxedBool,
    graph_id: int,
    forward_device: BoxedDeviceIndex,
) -> Callable[[list[object]], Sequence[torch.Tensor]]:
    """
    Forward compiler used for inference when freezing is enabled.

    Steps:
    - Run the joint-graph passes on ``aot_autograd_model`` (the partitioner
      is not called on this path).
    - If ``GraphLowering.decide_layout_opt`` enables layout optimization,
      convert convolution weights to channels-last.
    - Call ``freeze`` to fold ``dynamo_model``'s parameters into the graph.
      Only the inputs at ``preserved_arg_indices`` survive freezing.
    - Set parameters that were frozen away to ``None`` in the TracingContext
      so they can be deallocated.
    - Compile the frozen graph with ``inner_compile``.

    Returns:
        Under AOT compilation, the result of ``inner_compile`` unchanged (the
        generated code already takes only the remaining inputs). Otherwise a
        boxed-call wrapper that takes the full runtime argument list, keeps
        only the preserved arguments, clears the caller's list and calls the
        compiled function.

    NOTE(review): ``num_example_inputs`` is not used in this function.
    """
    from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze
    # partition_fn won't be called
    _recursive_joint_graph_passes(aot_autograd_model)
    layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True)
    if layout_opt:
        # make sure meta['val'] is properly setup
        fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
        convert_conv_weights_to_channels_last(aot_autograd_model)
    opt_model, preserved_arg_indices = freeze(
        dynamo_model,
        aot_autograd_model,
        aot_example_inputs,  # type: ignore[arg-type]
    )
    # Only the inputs that survived freezing are passed to inner_compile.
    aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
    fake_mode = detect_fake_mode(aot_example_inputs)
    # for freezing, all graph outputs should be user visible
    *_, model_outputs_node = opt_model.graph.nodes
    model_outputs = model_outputs_node.args[0]
    model_outputs_node.meta["user_visible_output_idxs"] = [
        idx for idx, n in enumerate(model_outputs) if isinstance(n, torch.fx.Node)
    ]
    static_input_idxs = []
    # constant params will be real tensors, not fake
    tracing_context = torch._guards.TracingContext.try_get()
    # Per-index offsets used by `wrapper` to map preserved (unwrapped) indices
    # back to positions in the runtime args. [0] means "no shift".
    unwrapped_args_offsets = [0]
    max_offset_idx = 0
    if tracing_context is not None:
        assert tracing_context.params_flat_unwrap_subclasses is not None
        params_flat_unwrap = tracing_context.params_flat_unwrap_subclasses
        max_offset_idx = max(0, len(params_flat_unwrap) - 1)
        preserved_indices_params_flat = OrderedSet[int]()
        unwrapped_idxs = tracing_context.params_unwrapped_to_flat_index
        assert unwrapped_idxs is not None
        current_offset = 0
        if len(params_flat_unwrap) > 0:
            unwrapped_args_offsets = []
        # Drop unwrapped params that were frozen into the graph. The offset grows
        # for each dropped entry that maps to the same wrapped param as the
        # previous entry.
        # NOTE(review): this appears to account for subclass params that unwrap
        # into several flat entries - confirm against
        # params_unwrapped_to_flat_index.
        for i in range(len(params_flat_unwrap)):
            if i not in preserved_arg_indices:
                params_flat_unwrap[i] = None
                if i > 0 and unwrapped_idxs[i] == unwrapped_idxs[i - 1]:
                    current_offset += 1
            else:
                preserved_indices_params_flat.add(unwrapped_idxs[i])
            unwrapped_args_offsets.append(current_offset)
        # Deallocate wrapped params, if all subelements were deallocated
        assert tracing_context.params_flat is not None
        for i in range(len(tracing_context.params_flat)):
            if i not in preserved_indices_params_flat:
                tracing_context.params_flat[i] = None
        if tracing_context.fw_metadata:
            static_input_idxs = tracing_context.fw_metadata.static_input_indices
    # Frozen constants are real tensors, so the fake mode must accept
    # non-fake inputs while compiling.
    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
        optimized_function = inner_compile(
            opt_model,
            aot_example_inputs,
            static_input_idxs=static_input_idxs,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            is_inference=True,
            boxed_forward_device_index=forward_device,
            layout_opt=layout_opt,
        )
    # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper
    # that drops constant-ified params
    if V.aot_compilation:
        return optimized_function
    def wrapper(args: list[object]) -> Sequence[torch.Tensor]:
        # Keep only the arguments that survived freezing, shifting each index
        # by the offset computed above (clamped to the last known offset).
        args_new = [
            args[i - unwrapped_args_offsets[min(i, max_offset_idx)]]
            for i in preserved_arg_indices
        ]
        # Boxed convention: clear the caller's list so it no longer holds the
        # inputs.
        args.clear()
        return optimized_function(args_new)
    wrapper._boxed_call = True  # type: ignore[attr-defined]
    return wrapper
def get_cpp_wrapper_config() -> dict[str, object]:
    """
    Return the config overrides applied while compiling with the C++ wrapper.

    cudagraphs are turned off in this mode. If they were requested, the skip
    is logged and counted first. ``triton.autotune_at_compile_time`` keeps
    any explicit setting; when unset (None) it defaults to whether Triton is
    available.
    """
    if config.triton.cudagraphs:
        log_cudagraph_skip_and_bump_counter(
            format_default_skip_message("cpp wrapper enabled")
        )

    autotune_at_compile_time = config.triton.autotune_at_compile_time
    if autotune_at_compile_time is None:
        autotune_at_compile_time = has_triton()

    return {
        "triton.autotune_at_compile_time": autotune_at_compile_time,
        "triton.autotune_cublasLt": False,
        "triton.cudagraphs": False,  # TODO: to be removed
        "triton.store_cubin": True,
    }
def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[None]:
    """
    Return a ``torch.cuda.device`` context for the graph's CUDA device when
    the graph uses exactly one CUDA device. Otherwise, or when CUDA is not
    available, return a no-op context.
    """
    if torch.cuda.is_available():
        unique_cuda: OrderedSet[torch.device] = OrderedSet(
            dev for dev in get_all_devices(gm) if dev.type == "cuda"
        )
        if len(unique_cuda) == 1:
            (only_device,) = unique_cuda
            return torch.cuda.device(only_device)  # type: ignore[return-value]
    return contextlib.nullcontext()
def compile_fx(
model_: GraphModule,
example_inputs_: Sequence[InputType],
inner_compile: Callable[..., OutputCode] = compile_fx_inner,
config_patches: Optional[dict[str, Any]] = None,
decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None,
ignore_shape_env: bool = False,
) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str], Weights]:
"""
Main entry point for compiling given FX graph. Despite the fact that this
lives in :mod:`torch._inductor`, this function is responsible for calling
into AOT Autograd (and we will eventually get a callback to
``inner_compile`` to perform actual compilation. In other words, this
function orchestrates end-to-end compilation for the inductor backend when
you use :func:`torch.compile`.
NB: This function TAKES OWNERSHIP of the input ``model_`` and can potentially
mutate it! Make a copy if you need to preserve the original GraphModule.
"""
# Wake up the AsyncCompile subproc pool as early as possible (if there's cuda).
if any(
isinstance(e, torch.Tensor) and e.device.type in ("cuda", "xpu")
for e in example_inputs_
):
torch._inductor.async_compile.AsyncCompile.wakeup()
# Some arguments trigger a recursive call to compile_fx. Handle these
# short circuits first, before anything else
if config_patches:
with config.patch(config_patches):
return compile_fx(
model_,
example_inputs_,
# need extra layer of patching as backwards is compiled out of scope
inner_compile=config.patch(config_patches)(inner_compile),
decompositions=decompositions,
ignore_shape_env=ignore_shape_env,
)
# TODO: This probably shouldn't be a recursive call
if config.cpp_wrapper:
with (
config.patch(
{
"cpp_wrapper": False, # reset to break recursive call to compile_fx
**get_cpp_wrapper_config(),
}
),
V.set_real_inputs(example_inputs_),
):
inputs_: Sequence[InputType] = example_inputs_
if isinstance(model_, GraphModule):
fake_inputs = [
node.meta.get("val")
for node in model_.graph.nodes
if node.op == "placeholder"
]
# Replace non-tensor (constant) inputs with Nones, since these are not being
# used anyways by the graph
fake_inputs = [
inp if isinstance(inp, torch.Tensor) else None
for inp in fake_inputs
]
if any(v is not None for v in fake_inputs):
# Validate devices before switching to fake tensors.
for idx, fi, i in zip(count(), fake_inputs, inputs_):
if fi is not None:
assert isinstance(i, torch.Tensor)
if fi.device != i.device:
raise ValueError(
f"Device mismatch between fake input and example input at position #{idx}: "
f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
"make sure torch.export() and torch.aot_compile() run on the same device."
)
inputs_ = fake_inputs # type: ignore[assignment]
from torch._export.non_strict_utils import _fakify_script_objects
fake_mode = detect_fake_mode(inputs_)
with _fakify_script_objects(model_, inputs_, {}, fake_mode) as (
patched_mod,
fake_args,
_,
_,
_,
):
return compile_fx(
patched_mod,
fake_args,
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
decompositions=decompositions,
ignore_shape_env=ignore_shape_env,
)
recursive_compile_fx = functools.partial(
compile_fx,
inner_compile=inner_compile,
decompositions=decompositions,
ignore_shape_env=ignore_shape_env,
)
if not graph_returns_tuple(model_):
return make_graph_return_tuple(
model_,
example_inputs_,
recursive_compile_fx,
)
if isinstance(model_, GraphModule) and isinstance(
model_.graph._codegen, _PyTreeCodeGen
):
# this graph is the result of dynamo.export()
return handle_dynamo_export_graph(
model_,
example_inputs_,
recursive_compile_fx,
)
# Do the actual work
with (
_use_lazy_graph_module(dynamo_config.use_lazy_graph_module),
enable_python_dispatcher(),
torch.fx.traceback.preserve_node_meta(config.trace.provenance_tracking),
torch._inductor.debug.reset_provenance_globals(),
):
# Pre-grad passes cannot be run if we weren't given a GraphModule.
# Dynamo will always produce a GraphModule, but this handles cases
# where a user directly passes a plain Module with the intention of
# having AOTAutograd trace it.
# TODO: Get rid of this?
if isinstance(model_, GraphModule):
# "before_pre_grad_graph" is used in inductor provenance
# tracking highlighter front-end.
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "before_pre_grad_graph",
"encoding": "string",
},
payload_fn=lambda: model_.print_readable(
print_output=False, include_stride=True, include_device=True
)
+ f"\n\n # graph id: {id(model_.graph)}",
)
pre_grad_graphs_log.debug(
"%s",
lazy_format_graph_code(
"BEFORE PRE GRAD",
model_,
include_stride=True,
include_device=True,
colored=True,
),
)
torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
if config.trace.provenance_tracking:
for node in model_.graph.nodes:
if node.stack_trace:
torch._inductor.debug._inductor_pre_grad_node_stack_trace[
node.name
] = node.stack_trace
model_ = _recursive_pre_grad_passes(model_, example_inputs_)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "after_pre_grad_graph",
"encoding": "string",
},
payload_fn=lambda: model_.print_readable(
print_output=False, include_stride=True, include_device=True
)
+ f"\n\n # graph id: {id(model_.graph)}",
)
# TODO: Move this before recursive pre-grad passes
# NB: This short circuit never occurs for Dynamo produced graphs
# (which are pre-flattened)
if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
return flatten_graph_inputs(
model_,
example_inputs_,
recursive_compile_fx,
)
assert not config._raise_error_for_testing
num_example_inputs = len(example_inputs_)
# Although cudagraphs may have been enabled via config, various
# conditions (which are tested within the bowels of Inductor) may
# force cudagraphs to be disabled. This mutable box lets us retrieve
# the final determination if cudagraphs actually can be used or not.
cudagraphs = BoxedBool(config.triton.cudagraphs)
# See [Backward Generation Handling]
forward_device = BoxedDeviceIndex(None)
# TODO: The modern style is to use CompileId from TracingContext to
# identify Inductor compilation. However, this CompileId cannot
# uniquely identify multiple Inductor compilations that arise from
# DDPOptimizer
graph_id = next(_graph_counter)
decompositions = (
decompositions if decompositions is not None else select_decomp_table()
)
def fw_compiler_base(
gm: GraphModule,
example_inputs: Sequence[InputType],
is_inference: bool,
) -> OutputCode:
with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
if is_inference:
# partition_fn won't be called
_recursive_joint_graph_passes(gm)
fixed = torch._inductor.utils.num_fw_fixed_arguments(
num_example_inputs, len(example_inputs)
)
model_outputs_node = output_node(gm)
if config.keep_output_stride:
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
num_model_outputs = len(model_outputs)
context = torch._guards.TracingContext.try_get()
# See Note [User Outputs in the inductor graph]
if context is not None and context.fw_metadata and not is_inference:
original_output_start_index = (
context.fw_metadata.num_mutated_inp_runtime_indices
)
else:
original_output_start_index = 0
if isinstance(model_, GraphModule):
*_, orig_model_outputs_node = model_.graph.nodes
assert orig_model_outputs_node.op == "output"
orig_model_outputs, _ = pytree.tree_flatten(
orig_model_outputs_node.args
)
num_orig_model_outputs = len(orig_model_outputs)
else:
num_orig_model_outputs = num_model_outputs
assert num_orig_model_outputs <= num_model_outputs
# Note [User Outputs in the inductor graph]
# We makes the following assumption
# For inference
# len(orig_model_outputs) == len(model_outputs)
# For training
# len(orig_model_outputs) <= len(model_outputs)
# During training, most of the time the model_outputs starts with
# original module's outputs followed by saved activations.
# But this can be not true if the model have inplace updated tensors.
# AOTAutograd will make those tensors being returned before the original
# module's output.
# To make things safe, we'll use original_output_start_index field
# set by AOTAutograd to decide where the original module outputs start.
orig_output_end_idx = (
original_output_start_index + num_orig_model_outputs
)
# Sanity check: we are about to splice out the "user" outputs from the full set
# of "graph" outputs. Make sure we're within bounds.
assert orig_output_end_idx <= num_model_outputs
model_outputs_node.meta["user_visible_output_idxs"] = [
idx
for idx in range(
original_output_start_index, orig_output_end_idx
)
if isinstance(model_outputs[idx], torch.fx.Node)
]
else:
model_outputs_node.meta["user_visible_output_idxs"] = []
# We also mark the invoke_subgraph outputs as user_visible to
# force the outputs of invoke_subgraph subgraph to follow the
# original strides
_recursive_record_user_visible_output_idxs(gm)
return inner_compile(
gm,
example_inputs,
static_input_idxs=get_static_input_idxs(fixed),
cudagraphs=cudagraphs,
graph_id=graph_id,
is_inference=is_inference,
boxed_forward_device_index=forward_device,
)
fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = (
functools.partial(fw_compiler_base, is_inference=False)
)
fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler)
if config.freezing and not torch.is_grad_enabled():
inference_compiler: Callable[..., Any] = functools.partial(
fw_compiler_freezing,
dynamo_model=model_,
num_example_inputs=num_example_inputs,
inner_compile=inner_compile,
cudagraphs=cudagraphs,
graph_id=graph_id,
forward_device=forward_device,
)
else:
inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
inference_compiler = SerializableAOTDispatchCompiler(
OutputCode, inference_compiler
)
def partition_fn(
gm: GraphModule,
joint_inputs: Sequence[object],
**kwargs: object,
) -> tuple[GraphModule, GraphModule]:
cuda_context = get_cuda_device_context(gm)
with cuda_context:
# We can skip the invoke_subgraph because the
# entire_partition_fn is called recursively for invoke_subgraph
# in partitioning.
_recursive_joint_graph_passes(gm, skip_invoke_subgraph=True)
static_lifetime_input_indices: Optional[list[int]] = kwargs.pop( # type: ignore[assignment]
"static_lifetime_input_indices", None
)
with dynamo_utils.dynamo_timed(
"min_cut_rematerialization_partition", log_pt2_compile_event=True
):
return min_cut_rematerialization_partition(
gm,
joint_inputs,
compiler="inductor",
static_lifetime_input_indices=static_lifetime_input_indices,
**kwargs,
)
@compile_time_strobelight_meta(phase_name="backward")
def bw_compiler(
gm: GraphModule, example_inputs: Sequence[InputType]
) -> OutputCode:
from torch._dynamo.convert_frame import compile_lock
with (
dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"),
compile_lock,
):
model_outputs_node = output_node(gm)
if config.bw_outputs_user_visible:
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
model_outputs_node.meta["user_visible_output_idxs"] = [
idx
for idx, n in enumerate(model_outputs)
if isinstance(n, torch.fx.Node)
]
else:
model_outputs_node.meta["user_visible_output_idxs"] = []
fixed = count_tangents(gm)
with (
config.patch(get_cpp_wrapper_config())
if config.cpp_wrapper
else contextlib.nullcontext()
):
return inner_compile(
gm,
example_inputs,
static_input_idxs=list(range(fixed)),
cudagraphs=cudagraphs,
is_backward=True,
graph_id=graph_id,
boxed_forward_device_index=forward_device,
)
bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler)
fake_mode = detect_fake_mode(
example_inputs_
) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
tracing_context = (
torch._guards.TracingContext.try_get()
or torch._guards.TracingContext(fake_mode)
)
if V.aot_compilation:
from .utils import is_valid_aoti_model_name
is_valid_aoti_model_name()
with functorch_config.patch(unlift_effect_tokens=True):
gm, graph_signature = aot_export_module(
model_,
example_inputs_,
trace_joint=False,
decompositions=decompositions,
)
from torch._export.utils import _detect_fake_mode_from_gm
fake_mode = _detect_fake_mode_from_gm(gm)
# aot_export_module doesn't account for constant tensor attributes
# so we end up having tensors that don't have fake vals attached.
# This can happen when upstream export is non-strict where we
# preserve the original module params/buffers. Once AOTI switches
# to ep.run_decompositions() flow to lower to post-autograd opset
# this will go away.
for node in gm.graph.nodes:
if node.op == "get_attr" and "val" not in node.meta:
target = attrgetter(node.target)(gm)
if isinstance(target, torch.Tensor):
node.meta["val"] = fake_mode.from_tensor(
target, static_shapes=True
)
elif isinstance(target, torch.ScriptObject):
node.meta["val"] = (
torch._library.fake_class_registry.maybe_to_fake_obj(
fake_mode, target
)
)
elif isinstance(target, FakeScriptObject):
node.meta["val"] = target
unlifted_gm = _unlift_graph(model_, gm, graph_signature)
if "dynamo_flat_name_to_original_fqn" in model_.meta:
unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[
"dynamo_flat_name_to_original_fqn"
]
if "dynamo_compile_id" in model_.meta:
unlifted_gm.meta["dynamo_compile_id"] = model_.meta["dynamo_compile_id"]
# Disable amp as in aot_dispatch_autograd (https://github.com/pytorch/pytorch/pull/86515)
# In inference_compiler (fw_compiler_base), _recursive_joint_graph_passes will call into
# _sfdp_init() to register patterns.
# When fallback_random is set to True, the sdpa patterns will be traced during runtime.
# If amp is turned on, the traced FP32 patterns will have prims.convert_element_type which
# will be the same as the generated FP16 patterns.
disable_amp = torch._C._is_any_autocast_enabled()
context = (
torch._C._DisableAutocast if disable_amp else contextlib.nullcontext
)
with V.set_fake_mode(fake_mode), compiled_autograd._disable(), context():
return inference_compiler(unlifted_gm, example_inputs_)
with (
V.set_fake_mode(fake_mode),
torch._guards.tracing(tracing_context),
compiled_autograd._disable(),
functorch_config.patch(unlift_effect_tokens=True),
):
try:
return aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
inference_compiler=inference_compiler,
decompositions=decompositions,
partition_fn=partition_fn,
keep_inference_input_mutations=True,
cudagraphs=cudagraphs,
boxed_forward_device_index=forward_device,
ignore_shape_env=ignore_shape_env,
)(model_, example_inputs_)
except ShortenTraceback as e:
# We will also shorten the traceback inside dynamo.
# This is only useful if inductor is called directly with an FX graph.
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
def graph_returns_tuple(gm: GraphModule) -> bool:
"""True if a FX graph returns a tuple"""
if not isinstance(gm, GraphModule):
return True # can't check this, assume true
(rv,) = output_node(gm).args
if isinstance(rv, (list, tuple)):
return True
if (
isinstance(rv, torch.fx.node.Node)
and hasattr(rv.target, "_schema")
and len(rv.target._schema.returns) > 1
and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns)
):
# for graphs whose result is one node with multiple outputs
return True
return False
def make_graph_return_tuple(
gm: GraphModule,
inputs: Sequence[InputType],
compile_gm: Callable[..., Any],
) -> Callable[..., Any]:
"""
Mutate gm so it returns a tuple. This is only needed for graphs
not created by torchdynamo that return non-tuples.
"""
node = output_node(gm)
(rv,) = node.args
rv, spec = pytree.tree_flatten(rv)
with gm.graph.inserting_before(node):
gm.graph.output(rv)
gm.graph.erase_node(node)
assert graph_returns_tuple(gm)
compiled_fn = compile_gm(gm, inputs)
@functools.wraps(compiled_fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)
return wrapper
def handle_dynamo_export_graph(
    gm: GraphModule,
    inputs: Sequence[InputType],
    compile_gm: Callable[..., Any],
) -> Callable[..., Any]:
    """
    Compile a graph whose codegen carries ``torch._dynamo.export`` pytree info.

    The graph's codegen object is swapped for the default
    ``torch.fx.graph.CodeGen`` (and ``gm`` is recompiled) so Inductor sees an
    ordinary flat-in/flat-out FX graph. The original codegen is retained and
    used to translate example inputs, call-time inputs, and call-time outputs
    via its ``process_inputs`` / ``process_outputs`` hooks.
    """
    export_codegen = gm.graph._codegen
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm.recompile()

    flat_example_inputs = export_codegen.process_inputs(*inputs)
    compiled_fn = compile_gm(gm, flat_example_inputs)

    @functools.wraps(compiled_fn)  # type: ignore[misc]
    def wrapper(*args: Any) -> Any:
        flat_args = export_codegen.process_inputs(*args)
        flat_out = compiled_fn(*flat_args)
        return export_codegen.process_outputs(flat_out)

    return wrapper
def _check_triton_bf16_support(graph: GraphLowering) -> None:
def warn_and_skip(device: Optional[torch.device]) -> Never:
from torch._dynamo.exc import SkipFrame
assert device is not None
device_interface = get_interface_for_device(device.type)
device_props = device_interface.get_device_properties(device)
warnings.warn(
f"{device_props.name} does not support bfloat16 compilation natively, skipping"
)
raise SkipFrame("BF16 is not supported")
for node in itertools.chain(graph.graph_inputs.values(), graph.graph_outputs):
if not isinstance(node, IRNode):
continue
device_type = get_device_type(node)
if (
not device_type
or not is_gpu(device_type)
or node.get_dtype() != torch.bfloat16
):
continue
# Print warning and skip frame if attempting to compile for bfloat16
# on device without hardware support for dtype
device_interface = get_interface_for_device(device_type)
if device_interface.is_bf16_supported(including_emulation=False):
return
warn_and_skip(node.get_device())
def _aoti_flatten_inputs(
gm: torch.fx.GraphModule,
args: Union[list[Any], tuple[Any, ...]],
kwargs: Optional[dict[str, Any]] = None,
*,
options: Optional[dict[str, Any]] = None,
) -> tuple[list[Any], dict[str, Any]]:
"""
Flatten the inputs to the graph module and return the flat inputs and options.
Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options.
"""
from .compile_fx import graph_returns_tuple
assert graph_returns_tuple(gm), (
"Graph output must be a tuple(). This is so that we can avoid "
"pytree processing of the outputs. Please change the module to "
"have tuple outputs."
)
# We will serialize the pytree info into the .so as constant strings
in_spec = None
out_spec = None
if isinstance(gm.graph._codegen, torch.fx.graph._PyTreeCodeGen):
codegen = gm.graph._codegen
gm.graph._codegen = torch.fx.graph.CodeGen()
gm.recompile()
if codegen.pytree_info.in_spec is not None:
in_spec = codegen.pytree_info.in_spec
if codegen.pytree_info.out_spec is not None:
out_spec = codegen.pytree_info.out_spec
else:
if hasattr(gm, "_in_spec"):
in_spec = gm._in_spec
if hasattr(gm, "_out_spec"):
out_spec = gm._out_spec
serialized_in_spec = pytree.treespec_dumps(in_spec) if in_spec is not None else ""
serialized_out_spec = (
pytree.treespec_dumps(out_spec) if out_spec is not None else ""
)
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
(args, kwargs or {})
)
if any(isinstance(x[1], torch.ScriptObject) for x in flat_args_with_path):
from torch._dynamo.exc import UserError, UserErrorType
raise UserError(
UserErrorType.INVALID_INPUT,
"TorchBind objects found in inputs. TorchBind object inputs are not supported in AOTInductor. "
"TorchBind objects can only be attributes.",
)
# Replace non-tensor (constant) inputs with Nones, since these are not being
# used anyways by the graph
flat_example_inputs = [
x[1] if isinstance(x[1], torch.Tensor) else None for x in flat_args_with_path
]
if in_spec is not None and received_spec != in_spec:
raise ValueError( # noqa: B904
"Trying to flatten user inputs with exported input tree spec: \n"
f"{in_spec}\n"
"but actually got inputs with tree spec of: \n"
f"{received_spec}"
)
options = (
{
"aot_inductor.serialized_in_spec": serialized_in_spec,
"aot_inductor.serialized_out_spec": serialized_out_spec,
}
if options is None
else {
**options,
"aot_inductor.serialized_in_spec": serialized_in_spec,
"aot_inductor.serialized_out_spec": serialized_out_spec,
}
)
return flat_example_inputs, options