"""
|
|
This module implements CUDA graphs support for TorchDynamo backends.
|
|
|
|
CUDA graphs allow for capturing and replaying GPU operations, which can significantly
|
|
reduce CPU overhead in GPU-accelerated PyTorch models. This module provides:
|
|
|
|
- CUDA graph creation and management for both forward and backward passes
|
|
- Input mutation detection and handling
|
|
- Device compatibility checking
|
|
- Stack trace management for debugging
|
|
- Integration with TorchInductor's cudagraph trees
|
|
|
|
The backend supports two main modes:
|
|
1. cudagraphs: Full CUDA graph support with both forward and backward pass optimization
|
|
2. cudagraphs_inner: Lower-level CUDA graph implementation used for benchmarking
|
|
|
|
Key components:
|
|
- CudagraphsBackend: Main backend class for CUDA graph integration
|
|
- Mutation detection utilities to ensure graph safety
|
|
- Device mapping and compatibility checks
|
|
- Stack trace collection for debugging
|
|
"""
|
|
|
|
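
# Example usage (an illustrative sketch, not part of this module): the "cudagraphs"
# backend registered below can be selected by name through torch.compile. The model
# and shapes here are hypothetical.
#
#     model = torch.nn.Linear(8, 8).cuda()
#     compiled = torch.compile(model, backend="cudagraphs")
#     out = compiled(torch.randn(4, 8, device="cuda"))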

import functools
from collections import defaultdict
from collections.abc import Sequence
from typing import Any, Callable, Optional

import torch
import torch.fx
from torch._dynamo import config
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import boxed_nop
from torch._inductor.cudagraph_utils import (
    BoxedDeviceIndex,
    check_multiple_devices_or_any_cpu_nodes,
    format_default_skip_message,
    get_mutation_stack_trace,
    get_placeholder_info,
    log_cudagraph_skip_and_bump_counter,
)
from torch._inductor.utils import (
    BoxedBool,
    count_tangents,
    get_first_incompatible_cudagraph_node,
    num_fw_fixed_arguments,
    output_node,
)
from torch.multiprocessing.reductions import StorageWeakRef

from .registry import register_backend


def find_input_mutations(g: torch.fx.Graph) -> set[int]:
    def meta_fk(meta: dict[str, Any]) -> Any:
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            if isinstance(meta_fk(n.meta), torch.Tensor):
                inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
            input_idx += 1
        elif n.op == "call_function":
            if not hasattr(n.target, "_schema"):
                continue

            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                    ]

    # TODO: error on unrecognized nodes
    return mutated_inputs
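
# Illustrative sketch of what find_input_mutations reports (assumes node.meta carries
# "val" entries, as in a make_fx trace; this snippet is hypothetical and not used by
# the module):
#
#     from torch.fx.experimental.proxy_tensor import make_fx
#
#     def f(x):
#         x.add_(1)  # in-place write to the first placeholder
#         return x * 2
#
#     gm = make_fx(f)(torch.ones(4))
#     find_input_mutations(gm.graph)  # expected to report {0}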


def get_device_node_mapping(
    gm: torch.fx.GraphModule,
) -> dict[torch.device, torch.fx.Node]:
    device_node_mapping: dict[torch.device, torch.fx.Node] = {}
    for n in gm.graph.nodes:
        t = n.meta.get("val", None)
        if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
            device_node_mapping[t.device] = n
    return device_node_mapping


def check_for_mutation_ignore_cuda_graph_managed_tensor(
    aot_model: torch.fx.GraphModule, num_fixed: int
) -> Optional[str]:
    mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
    if not mutation_indices:
        return None

    placeholders = get_placeholder_info(aot_model.graph)
    return get_mutation_stack_trace(placeholders, mutation_indices)


def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
    if not config.cudagraph_backend_support_input_mutation:
        if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
            aot_model, num_fixed
        ):
            return mut_skip

    if skip := check_multiple_devices_or_any_cpu_nodes(
        get_device_node_mapping(aot_model)
    ):
        return skip

    if node := get_first_incompatible_cudagraph_node(aot_model):
        return format_default_skip_message(f"incompatible op ({node.name})")

    return None


def get_device_index(gm: torch.fx.GraphModule) -> int:
    device = next(iter(get_device_node_mapping(gm)))
    assert device.type == "cuda"
    return device.index


def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]:
    output = output_node(gm)
    assert len(output.args) == 1
    args = output.args[0]
    if not hasattr(args, "__iter__"):
        return []
    return [
        (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
        for arg in args  # type: ignore[union-attr]
    ]


def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any:
    from torch._inductor.cudagraph_trees import cudagraphify_impl

    do_cudagraphs = BoxedBool(True)
    boxed_device_index = BoxedDeviceIndex(None)

    def forward_cudagraphs(
        aot_model: torch.fx.GraphModule,
        aot_inputs: list[Any],
        is_inference: bool = False,
    ) -> Any:
        interp = boxed_nop(aot_model, aot_inputs)
        fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
        if skip_msg := check_for_skip(aot_model, fixed):
            BoxedBool.disable(do_cudagraphs)
            log_cudagraph_skip_and_bump_counter(
                f"skipping cudagraphs due to {skip_msg}"
            )
            return interp

        boxed_device_index.set(get_device_index(aot_model))
        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=boxed_device_index.value,
            is_backward=False,
            is_inference=False,  # Q: should forward is_inference here?
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholder_info(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True  # type: ignore[attr-defined]
        return out

    def backward_cudagraphs(
        aot_model: torch.fx.GraphModule, aot_inputs: list[Any]
    ) -> Any:
        interp = boxed_nop(aot_model, aot_inputs)
        if not do_cudagraphs:
            return aot_model

        fixed = count_tangents(aot_model)
        if skip_msg := check_for_skip(aot_model, fixed):
            log_cudagraph_skip_and_bump_counter(
                f"skipping cudagraphs due to {skip_msg}"
            )

            # See [Backward Generation Handling]
            device_idx = boxed_device_index.value
            if device_idx is None:
                device_idx = 0  # Default to device 0 if not set
            manager = torch._inductor.cudagraph_trees.get_manager(
                device_idx, create_if_none_exists=False
            )
            assert manager is not None

            def fn(inputs: list[Any]) -> Any:
                # pyrefly: ignore # missing-attribute
                manager.set_to_running_backward()
                return aot_model(inputs)

            fn._boxed_call = True  # type: ignore[attr-defined]
            return fn

        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=get_device_index(aot_model),
            is_backward=True,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholder_info(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True  # type: ignore[attr-defined]
        return out

    aot_cudagraphs = aot_autograd(
        fw_compiler=forward_cudagraphs,
        bw_compiler=backward_cudagraphs,
        inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
        keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
    )
    return aot_cudagraphs(dynamo_model, dynamo_inputs)


class CudagraphsBackend:
    compiler_name = "cudagraphs"

    @staticmethod
    def reset() -> None:
        from torch._inductor.cudagraph_trees import reset_cudagraph_trees

        reset_cudagraph_trees()

    @staticmethod
    def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any:
        return cudagraphs(model, inputs)


# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
# for debugging and can serve as a perf baseline.
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())


def cudagraphs_inner(
    model: Callable[..., Any],
    inputs: Sequence[Any],
    copy_outputs: bool = True,
    copy_inputs: bool = True,
) -> Callable[..., Sequence[Any]]:
    """This isn't registered as a backend, but is used in some benchmarks"""
    assert isinstance(inputs, (list, tuple))
    if copy_inputs:
        static_inputs = [torch.zeros_like(x) for x in inputs]
    else:
        static_inputs = list(inputs)

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs: Any) -> Sequence[Any]:
        assert len(static_inputs) == len(new_inputs)
        if copy_inputs:
            for dst, src in zip(static_inputs, new_inputs):
                dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
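
# Illustrative benchmark-style usage of cudagraphs_inner (a sketch; the model and
# shapes are hypothetical and a CUDA device is required):
#
#     model = torch.nn.Linear(64, 64).cuda()
#     example_inputs = (torch.randn(32, 64, device="cuda"),)
#     run = cudagraphs_inner(model, example_inputs, copy_outputs=False)
#     outputs = run(*example_inputs)  # replays the captured CUDA graph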