Compare commits

...

2 Commits

Author SHA1 Message Date
1d3633accc workaround wrong maybe_subclass_meta.fw_metadata.static_input_indices 2025-11-01 17:45:00 -07:00
cba13b316a init 2025-11-01 15:42:39 -07:00
11 changed files with 171 additions and 22 deletions

View File

@@ -179,6 +179,10 @@ def aot_stage1_graph_capture(
)
)
if maybe_subclass_meta is not None and maybe_subclass_meta.fw_metadata is not None:
# TODO: maybe_subclass_meta.fw_metadata.static_input_indices is wrong; this is a hack to work around it.
maybe_subclass_meta.fw_metadata.static_input_indices = aot_state.fw_metadata.static_input_indices
return AOTGraphCapture(
wrappers=wrappers,
graph_module=graph,
@@ -2209,7 +2213,7 @@ def _aot_stage2b_compile_forward_or_inference(
# Set tracing context
if tracing_context := torch._guards.TracingContext.try_get():
tracing_context.fw_metadata = _get_inner_meta(
maybe_subclass_meta, fw_metadata
maybe_subclass_meta, fw_metadata  # maybe_subclass_meta may carry wrong static_input_indices; see the workaround in aot_stage1_graph_capture.
)
with TracingContext.report_output_strides() as fwd_output_strides:

View File

@@ -332,6 +332,21 @@ def list_mode_options(
mode_options: dict[str, dict[str, bool]] = {
"default": {},
# lightweight backend
"light": {
"fallback_by_default": True,
"use_dce": False,
"allow_buffer_reuse": False,
"reorder_for_peak_memory": False,
"reorder_for_compute_comm_overlap": False,
"triton.reorder_for_reducing_graph_partitions": False,
"use_pre_grad_passes": False,
"use_joint_graph_passes": False,
"use_post_grad_passes": True,
"aten_distributed_optimizations.enable_overlap_scheduling": True,
"use_decomposition": False,
"triton.cudagraphs": True,
},
# enable cudagraphs
"reduce-overhead": {
"triton.cudagraphs": True,

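Note: as a usage sketch (not part of the diff), the new mode would presumably be selected the same way as the existing "default" and "reduce-overhead" entries, via torch.compile's mode argument, which the inductor backend resolves against list_mode_options():

import torch

def f(x):
    return torch.relu(x) + 1

# "light" refers to the mode_options entry added above; the rest is standard
# torch.compile usage.
compiled = torch.compile(f, mode="light")
out = compiled(torch.randn(8, device="cuda"))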
View File

@@ -508,6 +508,9 @@ def _recursive_pre_grad_passes(
log_pt2_compile_event=True,
dynamo_compile_column_us="pre_grad_pass_time_us",
):
if not config.use_pre_grad_passes:
return gm
add_passes = config.add_pre_grad_passes
remove_passes = config.remove_pre_grad_passes
for subgraph_name in _get_subgraph_names(gm):
@@ -526,6 +529,9 @@ def _recursive_joint_graph_passes(
log_pt2_compile_event=True,
dynamo_compile_column_us="joint_graph_pass_time_us",
):
if not config.use_joint_graph_passes:
return
# invoke_subgraph already runs the _recursive_joint_graph_passes. In
# AOTAutograd, `run_joint_graph_passes_on_hops` partitions the
# invoke_subgraph HOP before calling the partitioner on the outer graph.
@@ -544,6 +550,9 @@ def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) ->
log_pt2_compile_event=True,
dynamo_compile_column_us="post_grad_pass_time_us",
):
if not config.use_post_grad_passes:
return
for subgraph_name in _get_subgraph_names(gm):
subgraph = getattr(gm, subgraph_name)
_recursive_post_grad_passes(subgraph, is_inference)
@@ -2634,6 +2643,9 @@ def _compile_fx_main(
decompositions if decompositions is not None else select_decomp_table()
)
if not config.use_decomposition:
decompositions = None
def fw_compiler_base(
gm: GraphModule,
example_inputs: Sequence[InputType],

View File

@@ -546,6 +546,25 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]
# Fall back to ATen for all ops by default, except for fx nodes with
# "compile_with_inductor" in node.meta["custom"]
fallback_by_default: bool = False
# Use dead code elimination
use_dce: bool = True
# Skip all decompositions when False
use_decomposition: bool = True
# Use fx graph passes
use_pre_grad_passes: bool = True
use_joint_graph_passes: bool = True
use_post_grad_passes: bool = True
# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False
@@ -1344,6 +1363,10 @@ class triton:
default=False,
)
# reorder nodes to minimize the number of graph partitions while
# not incurring large memory overhead
reorder_for_reducing_graph_partitions: bool = True
# assertions on the fast path
fast_path_cudagraph_asserts = False

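Note: a sketch (not part of the diff) of how these flags could be toggled, assuming they behave like other inductor config knobs, i.e. settable on torch._inductor.config directly or per-compile through torch.compile's options dict:

import torch
import torch._inductor.config as inductor_config

def fn(x):
    return x.sin() + x.cos()

# Patch a flag globally on the config module ...
inductor_config.use_dce = False

# ... or pass flags for a single compile via the inductor options dict.
compiled = torch.compile(fn, options={"use_pre_grad_passes": False})
out = compiled(torch.randn(4))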
View File

@@ -2305,6 +2305,7 @@ class CUDAGraphTreeManager:
ModelType,
OutputType,
]:
print(f"mode:{mode}, num inputs:{len(inputs)}, num static inputs:{len(static_input_idxs)}, num constants:{len(constants)}, static_input_idxs:{static_input_idxs}")
id = self.new_func_id()
self.ids_to_stack_traces[id] = stack_traces
self.ids_to_funcs[id] = WrappedFunction(

View File

@@ -7,7 +7,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
from torch._dynamo.utils import counters, get_metrics_context
from torch._inductor.utils import GraphPartitionMap, InputType
from torch._inductor.utils import GraphPartitionMap, InputType, CUDAGraphWrapperMetadata
from torch.utils._ordered_set import OrderedSet
from .utils import is_using_cudagraph_partition
@@ -420,3 +420,52 @@ def get_partition_cudagraph_metadata(
partition_stack_traces,
partition_constants,
)


class CUDAGraphWrapper:
    def __init__(
        self,
        runnable: Callable,
        graph_pool: Optional[torch.cuda.graph_pool_handle] = None,
        input_clone_indices: Optional[list[int]] = None,
    ):
        self.runnable = runnable
        self.graph_pool = (
            graph_pool if graph_pool is not None else torch.cuda.graph_pool_handle()
        )
        self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
        self.input_clone_indices = input_clone_indices
        self.input_buffers: Optional[list[torch.Tensor]] = None
        self.output = None
        self.has_warmup = False

    def __call__(self, *args: torch.Tensor):
        if not self.has_warmup:
            # Run once eagerly so lazy initialization (e.g. autotuning, pools)
            # happens outside of graph capture.
            self.has_warmup = True
            return self.runnable(*args)

        if self.cudagraph is None:
            self.cudagraph = torch.cuda.CUDAGraph()
            # Clone inputs at the marked indices into wrapper-owned buffers so
            # the captured graph reads from stable addresses.
            self.input_buffers = [
                arg.clone()
                if self.input_clone_indices and idx in self.input_clone_indices
                else arg
                for idx, arg in enumerate(args)
            ]
            with torch.cuda.graph(self.cudagraph, pool=self.graph_pool):
                # `output` is managed by pytorch's cudagraph pool
                # TODO: use weak ref for output to reuse memory
                self.output = self.runnable(*self.input_buffers)

        if self.input_clone_indices:
            # Refresh the cloned input buffers with the latest values before replay.
            for i in self.input_clone_indices:
                self.input_buffers[i].copy_(args[i])
        self.cudagraph.replay()
        return self.output


def cudagraph_wrapper(fn: Callable, metadata: CUDAGraphWrapperMetadata) -> Callable:
    # TODO: use metadata.static_input_idxs / metadata.mutated_input_idxs to
    # derive input_clone_indices instead of dropping the metadata.
    return CUDAGraphWrapper(fn)

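Note: a usage sketch (not part of the diff) for CUDAGraphWrapper on its own, assuming it is handed a CUDA-resident callable directly. The first call runs eagerly as warmup, the second call captures the graph, and later calls refresh the marked inputs and replay:

import torch

def step(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1

# input_clone_indices=[0] marks argument 0 as non-static: it is cloned into a
# wrapper-owned buffer at capture time and refreshed before every replay.
wrapped = CUDAGraphWrapper(step, input_clone_indices=[0])
x = torch.randn(16, device="cuda")
for _ in range(3):  # warmup, capture, replay
    y = wrapped(x)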
View File

@@ -114,6 +114,7 @@ from .utils import (
ValueWithLineMap,
)
from .virtualized import NullHandler, V
from torch.fx.passes.regional_inductor import _needs_inductor_compile
if TYPE_CHECKING:
@@ -318,6 +319,13 @@ def mark_nodes_dislike_padding(
cur.meta["dislike_padding"] = True
def should_fallback_by_default(node: torch.fx.Node) -> bool:
if not config.fallback_by_default:
return False
return not _needs_inductor_compile(node)
class GraphLowering(torch.fx.Interpreter):
graph_outputs: list[ir.IRNode]
@@ -1626,7 +1634,7 @@ class GraphLowering(torch.fx.Interpreter):
fallback_node_due_to_unsupported_type(n)
or CompilerBisector.disable_subsystem(
"inductor", "lowerings", lambda: repr(n)
)
) or should_fallback_by_default(n)
)
):
debug("fallback_handler")

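Note: per the comment added in config.py above, fallback_by_default means inductor only lowers fx nodes that carry "compile_with_inductor" in node.meta["custom"]; everything else takes the ATen fallback. A toy, hypothetical restatement of that decision for illustration (the real check defers to _needs_inductor_compile):

import torch

def toy_should_fallback(node: torch.fx.Node) -> bool:
    # Illustration only; not the predicate used by the code above.
    custom = node.meta.get("custom") or {}
    return "compile_with_inductor" not in custom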
View File

@@ -254,6 +254,35 @@ def cudagraph_post_compile(
)
def get_cudagraph_metadata_for_all_partitions(
compiled_graph: CompiledFxGraph,
constants: dict[str, torch.Tensor],
) -> list[CudagraphMetadata]:
assert compiled_graph.current_callable is not None
assert compiled_graph.recursively_apply_fns is not None
static_input_idxs = OrderedSet(compiled_graph.fx_kwargs["static_input_idxs"] or ())
mutated_input_idxs = compiled_graph.mutated_input_idxs
graph_metadata = CudagraphMetadata(
compiled_graph.cudagraph_info.placeholders,
static_input_idxs,
mutated_input_idxs,
compiled_graph.cudagraph_info.stack_traces,
constants,
)
res = []
for partition_map in compiled_graph.partition_maps:
partition_metadata = get_partition_cudagraph_metadata(
partition_map,
graph_metadata,
)
res.append(partition_metadata)
return res
def cudagraph_partition_post_compile(
example_inputs: Sequence[InputType],
compiled_graph: CompiledFxGraph,
@@ -283,35 +312,27 @@ def cudagraph_partition_post_compile(
from .compile_fx import cudagraphify
assert compiled_graph.current_callable is not None
assert compiled_graph.recursively_apply_fns is not None
is_inference = compiled_graph.fx_kwargs["is_inference"]
is_backward = compiled_graph.fx_kwargs["is_backward"]
static_input_idxs = OrderedSet(compiled_graph.fx_kwargs["static_input_idxs"] or ())
mutated_input_idxs = compiled_graph.mutated_input_idxs
device_index = next(iter(compiled_graph.device_idxs))
graph_metadata = CudagraphMetadata(
compiled_graph.cudagraph_info.placeholders,
static_input_idxs,
mutated_input_idxs,
compiled_graph.cudagraph_info.stack_traces,
constants,
)
prepare_cudagraph_post_compile(
compiled_graph, example_inputs, boxed_forward_device_index
)
partition_metadatas = get_cudagraph_metadata_for_all_partitions(
compiled_graph, constants
)
# cudagraphify each partition function, assuming every graph partition function
# is cudagraphable. Non-cudagraphable ops (e.g., cpu ops) are inlined into
# `call` function and not included in partition functions.
cudagraphify_fns = []
for partition_map in compiled_graph.partition_maps:
partition_metadata = get_partition_cudagraph_metadata(
partition_map,
graph_metadata,
)
# TODO: remove this. only for debug
for i, partition_metadata in enumerate(partition_metadatas):
# print("mutated_input_idxs: ", partition_metadata.mutated_input_idxs)
print(f"post_compile. is_backward:{is_backward}, i:{i}, static_input_idxs: {partition_metadata.static_input_idxs}")
cudagraphify_fn = partial(
cudagraphify,
@@ -643,8 +664,12 @@ class CompiledFxGraph(OutputCode):
assert self.recursively_apply_fns is not None
assert self.compiled_fn_runner is not None
num_partitions = len(self.compiled_fn_runner.partitions)
# Step 1: collect per-partition static/mutated input indices for the wrapper metadata
partition_metadatas = get_cudagraph_metadata_for_all_partitions(self, constants.unwrap(self))
wrapper_metadatas = [
CUDAGraphWrapperMetadata(num_partitions, i)
CUDAGraphWrapperMetadata(num_partitions, i, partition_metadatas[i].static_input_idxs, partition_metadatas[i].mutated_input_idxs)
for i in range(num_partitions)
]
customized_wrapper = _unstable_customized_partition_wrapper.wrapper

View File

@@ -2713,6 +2713,7 @@ class Scheduler:
if (
torch._inductor.config.graph_partition
and torch._inductor.config.triton.cudagraphs
and torch._inductor.config.triton.reorder_for_reducing_graph_partitions
):
self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)
@@ -3160,6 +3161,9 @@ class Scheduler:
"""
Remove any nodes without users
"""
if not config.use_dce:
return
# self.nodes is in topological order, so by iterating in reverse order
# we have visited (and potentially removed) all users before visiting a
# given node.

View File

@@ -3861,6 +3861,14 @@ class CUDAGraphWrapperMetadata:
# Index of the current partition.
partition_index: int
# Indices of static inputs, whose addresses are stable across calls and so do
# not need to be copied into the captured graph's buffers before replay.
static_input_idxs: list[int]
# Indices of mutated inputs. If a mutated input is NOT a static input, the
# mutation applied inside the graph must be copied back out after
# graph.replay() so the caller observes it.
mutated_input_idxs: list[int]
PartitionFnType = Callable[..., Any]
CUDAGraphWrapperType = Callable[

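Note: a sketch (not part of the diff) of constructing the extended metadata, mirroring the call site in CompiledFxGraph above. The name of the first field is not visible in this hunk, so it is passed positionally here:

from torch._inductor.utils import CUDAGraphWrapperMetadata

meta = CUDAGraphWrapperMetadata(
    4,  # total number of partitions (field name not shown in this hunk)
    partition_index=0,
    static_input_idxs=[0, 2],
    mutated_input_idxs=[1],
)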
View File

@@ -112,7 +112,7 @@ def _compile_submod(gm, prefix):
return gm
def _needs_inductor_compile(node):
def _needs_inductor_compile(node: torch.fx.Node) -> bool:
return (
node.op not in ("placeholder", "output")
and hasattr(node, "meta")