pytorch/torch/_dynamo/backends/cudagraphs.py
Maggie Moss c855f8632e Pyrefly suppressions 7/n (#164913)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
 INFO 0 errors (6,884 ignored)
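
For illustration, the suppressions added in step 3 are inline "# pyrefly: ignore" comments placed directly above the flagged expression, like the one that appears in backward_cudagraphs in this file:

    # pyrefly: ignore # missing-attribute
    manager.set_to_running_backward()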

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164913
Approved by: https://github.com/oulgen
2025-10-08 07:27:17 +00:00

"""
This module implements CUDA graphs support for TorchDynamo backends.
CUDA graphs allow for capturing and replaying GPU operations, which can significantly
reduce CPU overhead in GPU-accelerated PyTorch models. This module provides:
- CUDA graph creation and management for both forward and backward passes
- Input mutation detection and handling
- Device compatibility checking
- Stack trace management for debugging
- Integration with TorchInductor's cudagraph trees
The backend supports two main modes:
1. cudagraphs: Full CUDA graph support with both forward and backward pass optimization
2. cudagraphs_inner: Lower-level CUDA graph implementation used for benchmarking
Key components:
- CudagraphsBackend: Main backend class for CUDA graph integration
- Mutation detection utilities to ensure graph safety
- Device mapping and compatibility checks
- Stack trace collection for debugging
"""
import functools
from collections import defaultdict
from collections.abc import Sequence
from typing import Any, Callable, Optional

import torch
import torch.fx
from torch._dynamo import config
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import boxed_nop
from torch._inductor.cudagraph_utils import (
    BoxedDeviceIndex,
    check_multiple_devices_or_any_cpu_nodes,
    format_default_skip_message,
    get_mutation_stack_trace,
    get_placeholder_info,
    log_cudagraph_skip_and_bump_counter,
)
from torch._inductor.utils import (
    BoxedBool,
    count_tangents,
    get_first_incompatible_cudagraph_node,
    num_fw_fixed_arguments,
    output_node,
)
from torch.multiprocessing.reductions import StorageWeakRef

from .registry import register_backend


def find_input_mutations(g: torch.fx.Graph) -> set[int]:
    def meta_fk(meta: dict[str, Any]) -> Any:
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            if isinstance(meta_fk(n.meta), torch.Tensor):
                inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
            input_idx += 1
        elif n.op == "call_function":
            if not hasattr(n.target, "_schema"):
                continue
            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                    ]
    # TODO: error on unrecognized nodes
    return mutated_inputs


def get_device_node_mapping(
    gm: torch.fx.GraphModule,
) -> dict[torch.device, torch.fx.Node]:
    device_node_mapping: dict[torch.device, torch.fx.Node] = {}
    for n in gm.graph.nodes:
        t = n.meta.get("val", None)
        if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
            device_node_mapping[t.device] = n
    return device_node_mapping


def check_for_mutation_ignore_cuda_graph_managed_tensor(
    aot_model: torch.fx.GraphModule, num_fixed: int
) -> Optional[str]:
    mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
    if not mutation_indices:
        return None

    placeholders = get_placeholder_info(aot_model.graph)
    return get_mutation_stack_trace(placeholders, mutation_indices)


def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
    if not config.cudagraph_backend_support_input_mutation:
        if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
            aot_model, num_fixed
        ):
            return mut_skip

    if skip := check_multiple_devices_or_any_cpu_nodes(
        get_device_node_mapping(aot_model)
    ):
        return skip

    if node := get_first_incompatible_cudagraph_node(aot_model):
        return format_default_skip_message(f"incompatible op ({node.name})")

    return None


def get_device_index(gm: torch.fx.GraphModule) -> int:
    device = next(iter(get_device_node_mapping(gm)))
    assert device.type == "cuda"
    return device.index


def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]:
    output = output_node(gm)
    assert len(output.args) == 1
    args = output.args[0]
    if not hasattr(args, "__iter__"):
        return []
    return [
        (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
        for arg in args  # type: ignore[union-attr]
    ]


def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any:
    from torch._inductor.cudagraph_trees import cudagraphify_impl

    do_cudagraphs = BoxedBool(True)
    boxed_device_index = BoxedDeviceIndex(None)

    def forward_cudagraphs(
        aot_model: torch.fx.GraphModule,
        aot_inputs: list[Any],
        is_inference: bool = False,
    ) -> Any:
        interp = boxed_nop(aot_model, aot_inputs)
        fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
        if skip_msg := check_for_skip(aot_model, fixed):
            BoxedBool.disable(do_cudagraphs)
            log_cudagraph_skip_and_bump_counter(
                f"skipping cudagraphs due to {skip_msg}"
            )
            return interp

        boxed_device_index.set(get_device_index(aot_model))
        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=boxed_device_index.value,
            is_backward=False,
            is_inference=False,  # Q: should forward is_inference here?
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholder_info(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True  # type: ignore[attr-defined]
        return out

    def backward_cudagraphs(
        aot_model: torch.fx.GraphModule, aot_inputs: list[Any]
    ) -> Any:
        interp = boxed_nop(aot_model, aot_inputs)
        if not do_cudagraphs:
            return aot_model

        fixed = count_tangents(aot_model)
        if skip_msg := check_for_skip(aot_model, fixed):
            log_cudagraph_skip_and_bump_counter(
                f"skipping cudagraphs due to {skip_msg}"
            )

            # See [Backward Generation Handling]
            device_idx = boxed_device_index.value
            if device_idx is None:
                device_idx = 0  # Default to device 0 if not set
            manager = torch._inductor.cudagraph_trees.get_manager(
                device_idx, create_if_none_exists=False
            )
            assert manager is not None

            def fn(inputs: list[Any]) -> Any:
                # pyrefly: ignore # missing-attribute
                manager.set_to_running_backward()
                return aot_model(inputs)

            fn._boxed_call = True  # type: ignore[attr-defined]
            return fn

        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=get_device_index(aot_model),
            is_backward=True,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholder_info(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True  # type: ignore[attr-defined]
        return out

    aot_cudagraphs = aot_autograd(
        fw_compiler=forward_cudagraphs,
        bw_compiler=backward_cudagraphs,
        inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
        keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
    )
    return aot_cudagraphs(dynamo_model, dynamo_inputs)


class CudagraphsBackend:
    compiler_name = "cudagraphs"

    @staticmethod
    def reset() -> None:
        from torch._inductor.cudagraph_trees import reset_cudagraph_trees

        reset_cudagraph_trees()

    @staticmethod
    def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any:
        return cudagraphs(model, inputs)


# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
# for debugging and can serve as a perf baseline.
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())


def cudagraphs_inner(
    model: Callable[..., Any],
    inputs: Sequence[Any],
    copy_outputs: bool = True,
    copy_inputs: bool = True,
) -> Callable[..., Sequence[Any]]:
    """This isn't registered as a backend, but is used in some benchmarks"""
    assert isinstance(inputs, (list, tuple))
    if copy_inputs:
        static_inputs = [torch.zeros_like(x) for x in inputs]
    else:
        static_inputs = list(inputs)

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs: Any) -> Sequence[Any]:
        assert len(static_inputs) == len(new_inputs)
        if copy_inputs:
            for dst, src in zip(static_inputs, new_inputs):
                dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
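

# Illustrative benchmark-style use of cudagraphs_inner (a minimal sketch;
# assumes a CUDA-capable device, and the model and shapes are placeholders):
#
#     model = torch.nn.Linear(16, 16).cuda()
#     example_inputs = [torch.randn(8, 16, device="cuda")]
#     graphed = cudagraphs_inner(model, example_inputs)
#     (out,) = graphed(torch.randn(8, 16, device="cuda"))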