# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
from enum import Enum
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
from torch._dynamo.utils import counters
from torch._inductor.utils import InputType
from torch.utils._ordered_set import OrderedSet


if TYPE_CHECKING:
    from collections.abc import Sequence


perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)


OutputType = list[Optional[Union[int, torch.Tensor]]]
ModelType = Callable[[list[InputType]], OutputType]


@dataclasses.dataclass(frozen=True)
class FunctionID:
    """Unique counter of a function wrapped in cudagraphify_impl."""

    id: int


@dataclasses.dataclass(frozen=True)
class PlaceholderInfo:
    """
    A serializable version of torch.fx.Node that contains information
    pertinent to placeholder stack traces. We use these in logging and
    error messages related to cudagraphs, and cache the results.
    """

    name: str
    stack_trace: Optional[str]
    # This field is recursive, but never cyclic (since a node never uses itself)
    users: list[PlaceholderInfo]
    mutating_use_stack_trace: Optional[str]


@dataclasses.dataclass(frozen=True)
class WrappedFunction:
    """
    Represents a function that you want to record for CUDA graph replay,
    with a little more metadata so we can identify if we have an applicable
    CUDA graph in our CUDA graph tree for it.
    """

    model: Callable[..., Any]
    static_input_idxs: Sequence[int]
    id: FunctionID
    constants: tuple[torch.Tensor, ...]
    placeholders: Sequence[PlaceholderInfo]
    mutated_input_idxs: Sequence[int]


def get_mutating_use_stack_trace_from_node(
    placeholder_node: torch.fx.Node,
) -> Optional[str]:
    # reinplaced uses might have a single, non-copy_ use
    if len(placeholder_node.users) == 1:
        return next(iter(placeholder_node.users)).meta.get("stack_trace", None)

    for use in placeholder_node.users:
        if use.target == torch.ops.aten.copy_.default:
            if stack_trace := use.meta.get("stack_trace", None):
                return stack_trace

    return None


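# A minimal sketch of how the helper above is exercised (hypothetical graph
# and stack-trace string, not part of this module): an input mutation shows
# up as an `aten.copy_` user of the placeholder, and the copy's recorded
# stack trace identifies the mutating line of user code.
#
#   >>> g = torch.fx.Graph()
#   >>> arg = g.placeholder("arg0_1")
#   >>> val = g.placeholder("arg1_1")
#   >>> cp = g.call_function(torch.ops.aten.copy_.default, (arg, val))
#   >>> cp.meta["stack_trace"] = "File 'model.py', line 3, in forward"
#   >>> get_mutating_use_stack_trace_from_node(arg)
#   "File 'model.py', line 3, in forward"

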
def get_mutating_use_stack_trace(placeholder_info: PlaceholderInfo) -> Optional[str]:
    return placeholder_info.mutating_use_stack_trace


def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo:
    name = placeholder_node.name
    stack_trace = placeholder_node.meta.get("stack_trace", None)
    users = []
    mutating_use_stack_trace = None
    # Only recurse to users once, since we only care about the users' stack traces
    if placeholder_node.op == "placeholder":
        users = [to_placeholder_info(i) for i in placeholder_node.users]
        mutating_use_stack_trace = get_mutating_use_stack_trace_from_node(
            placeholder_node
        )

    return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace)


def get_placeholder_info(graph: torch.fx.Graph) -> list[PlaceholderInfo]:
    return [
        to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder"
    ]


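# Example (a small sketch; `f` is a hypothetical function, not part of this
# module): the placeholders of any FX graph can be snapshotted this way.
#
#   >>> def f(x):
#   ...     return (x.add_(1),)
#   >>> gm = torch.fx.symbolic_trace(f)
#   >>> [info.name for info in get_placeholder_info(gm.graph)]
#   ['x']

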
def format_default_skip_message(reason: str) -> str:
    return f"skipping cudagraphs due to {reason}"


def get_mutation_stack_trace(
    placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int]
) -> str:
    stack_trace: Optional[str] = ""

    for idx in mutation_indices:
        placeholder = placeholders[idx]
        if stack_trace := get_mutating_use_stack_trace(placeholder):
            break

    msg = format_default_skip_message(
        f"mutated inputs ({len(mutation_indices)} instances)"
    )
    if stack_trace:
        return f"{msg}. Found from:\n{stack_trace}"

    return msg


def check_for_mutation(
    func: WrappedFunction,
    inputs: list[InputType],
    is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool],
) -> Optional[str]:
    # doesn't work for non-trees because the warmup run would apply mutation twice
    if torch._inductor.config.triton.cudagraph_trees:
        # checking if mutation is only on parameters/static inputs
        mutation_indices: Sequence[int] = [
            idx
            for idx in func.mutated_input_idxs
            if not (
                idx in func.static_input_idxs
                or is_cuda_graph_recorded_tensor(inputs[idx])  # type: ignore[arg-type]
            )
        ]
    else:
        mutation_indices = func.mutated_input_idxs

    static_inputs_log.debug(
        "check mutation static input indices: %s", func.static_input_idxs
    )
    static_inputs_log.debug("check mutation mutation indices: %s", mutation_indices)

    return (
        get_mutation_stack_trace(func.placeholders, mutation_indices)
        if mutation_indices
        else None
    )


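# Sketch of the decision above (hypothetical values, assuming the default
# cudagraph_trees=True config): an index in `mutated_input_idxs` only
# triggers a skip when it is neither a static input nor a tensor already
# managed by a recorded cudagraph.  Here the single mutated input is also
# static, so no skip reason is returned.
#
#   >>> wf = WrappedFunction(
#   ...     model=lambda args: [],
#   ...     static_input_idxs=[0],
#   ...     id=FunctionID(0),
#   ...     constants=(),
#   ...     placeholders=(),
#   ...     mutated_input_idxs=[0],
#   ... )
#   >>> check_for_mutation(wf, [torch.randn(2)], lambda t: False) is None
#   True

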
def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]:
    for use in node.users:
        if stack_trace := use.meta.get("stack_trace", None):
            return stack_trace
    return None


def check_multiple_devices_or_any_cpu_nodes(
    device_node_mapping: dict[torch.device, torch.fx.Node]
) -> Optional[str]:
    if cpu_node := device_node_mapping.get(torch.device("cpu")):
        msg = f"cpu device ({cpu_node.name})"
        if stack_trace := _get_use_stack_trace(cpu_node):
            return format_default_skip_message(f"{msg}. Found from:\n{stack_trace}")

        return format_default_skip_message(msg)

    if (
        len(device_node_mapping) == 1
        and next(iter(device_node_mapping.keys())).type == "cuda"
    ):
        return None

    keys_repr = (repr(key) for key in device_node_mapping.keys())
    return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}")


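# A quick sketch of the check (hypothetical single-node graph): a pure-cuda
# mapping passes, while any cpu node yields a skip message.
#
#   >>> g = torch.fx.Graph()
#   >>> n = g.placeholder("x")
#   >>> check_multiple_devices_or_any_cpu_nodes({torch.device("cuda"): n}) is None
#   True
#   >>> check_multiple_devices_or_any_cpu_nodes({torch.device("cpu"): n})
#   'skipping cudagraphs due to cpu device (x)'

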
def check_lowering_disable_cudagraph(
    device_node_mapping: dict[torch.device, torch.fx.Node]
) -> Optional[str]:
    return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)


def log_cudagraph_skip_and_bump_counter(msg: str) -> None:
    perf_hint_log.warning(msg)
    counters["inductor"]["cudagraph_skips"] += 1


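# Each skip both warns through the perf-hints artifact logger and bumps a
# dynamo counter, so tests and tooling can assert on skips without parsing
# logs.  A sketch:
#
#   >>> before = counters["inductor"]["cudagraph_skips"]
#   >>> log_cudagraph_skip_and_bump_counter(
#   ...     format_default_skip_message("mutated inputs")
#   ... )
#   >>> counters["inductor"]["cudagraph_skips"] == before + 1
#   True

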
@dataclasses.dataclass
class BoxedDeviceIndex:
    value: Optional[int]

    def set(self, device_idx: Optional[int]) -> None:
        assert device_idx is None or isinstance(device_idx, int)
        self.value = device_idx


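# The "boxed" value lets several compilation stages share one mutable cell:
# whoever learns the device index later can fill it in for everyone holding
# the box.  A minimal sketch:
#
#   >>> box = BoxedDeviceIndex(None)
#   >>> box.set(0)
#   >>> box.value
#   0

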
def check_for_mutation_ignore_cuda_graph_managed_tensor(
    gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: Sequence[int]
) -> Optional[str]:
    default_msg = format_default_skip_message("mutated inputs")

    # doesn't work for non-trees because the warmup run would apply mutation twice
    if torch._inductor.config.triton.cudagraph_trees:
        unique_idxs = OrderedSet(static_input_idxs)
        # checking if mutation is only on parameters/static inputs
        mutation_indices = [
            idx for idx in compiled_graph.mutated_input_idxs if idx not in unique_idxs
        ]
        has_mutation = len(mutation_indices) != 0
        if not has_mutation:
            return None
        placeholders = get_placeholder_info(gm.graph)
        return get_mutation_stack_trace(placeholders, mutation_indices)

    else:
        has_mutation = len(compiled_graph.mutated_inputs) != 0
        return None if not has_mutation else default_msg


def get_placeholder_stack_trace(placeholder: PlaceholderInfo) -> Optional[str]:
    """
    Gets the first non-empty stack trace of a placeholder or its users.
    """
    if placeholder.stack_trace:
        return placeholder.stack_trace

    for user in placeholder.users:
        if user.stack_trace:
            return user.stack_trace

    return None


class CheckInvariantStatus(Enum):
    # Check invariant succeeded
    SUCCESS = 1

    # Previously managed data pointers are not stable
    CudagraphManagedIdxMismatch = 2

    # Static tensor input addresses are not stable
    StaticInputIdxMismatch = 3

    # Expected dead indices before graph are live
    ExpectedDeadIndicesBeforeGraphMismatch = 4

    def __str__(self) -> str:
        if self.name == "CudagraphManagedIdxMismatch":
            return "cudagraph managed tensor data pointer changed"
        elif self.name == "StaticInputIdxMismatch":
            return "static input data pointer changed"
        elif self.name == "ExpectedDeadIndicesBeforeGraphMismatch":
            return "expected dead indices before graph are live"
        else:
            return f"{self.name}: {self.value}"


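# The human-readable form is what lands in the invariant-mismatch messages
# built below, e.g.:
#
#   >>> str(CheckInvariantStatus.StaticInputIdxMismatch)
#   'static input data pointer changed'
#   >>> str(CheckInvariantStatus.SUCCESS)
#   'SUCCESS: 1'

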
def log_data_ptr_mismatch(
    placeholders: Sequence[PlaceholderInfo],
    inputs: list[InputType],
    recorded_data_ptr: Sequence[Optional[int]],
    target_idxs: Sequence[int],
    mismatch: CheckInvariantStatus,
) -> str:
    """
    Logs the mismatch between input data pointers and recorded data pointers.
    This checks only the indices in target_idxs.
    """
    assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(
        placeholders
    ), "length mismatch between inputs, recorded_data_ptr, and placeholders"

    t_tensors = [inputs[i] for i in target_idxs]
    t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs]
    error_msg = f"{mismatch}.\n"
    for i, (tensor, data_ptr) in enumerate(zip(t_tensors, t_data_ptrs)):
        assert isinstance(tensor, torch.Tensor)
        index = target_idxs[i]
        if tensor.data_ptr() != data_ptr:
            placeholder = placeholders[index]
            error_msg = (
                f"{error_msg}input name: {placeholder.name}. "
                f"data pointer changed from {data_ptr} to {tensor.data_ptr()}. "
                f"input stack trace: {get_placeholder_stack_trace(placeholder)}\n"
            )
    return error_msg


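# A hedged sketch of the message format (hypothetical placeholder and a
# deliberately stale recorded pointer of 0):
#
#   >>> ph = PlaceholderInfo("arg0_1", None, [], None)
#   >>> t = torch.zeros(2)
#   >>> msg = log_data_ptr_mismatch(
#   ...     [ph], [t], [0], [0], CheckInvariantStatus.StaticInputIdxMismatch
#   ... )
#   >>> msg.startswith("static input data pointer changed.")
#   True

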
def maybe_warning_due_to_dynamic_shape(
    fn_cache: dict[tuple[int, ...], Callable[..., Any]],
    new_int_key: Any,
) -> bool:
    num_cudagraphs = len(fn_cache.keys()) + 1

    def warn_msg():
        return (
            "CUDAGraph supports dynamic shapes by recording a new graph for each "
            "distinct input size. Recording too many CUDAGraphs may lead to "
            f"extra overhead. We have observed {num_cudagraphs} distinct sizes. "
            "Please consider the following options for better performance: "
            "a) padding inputs to a small number of fixed shapes; or b) set "
            "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
            "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
            "to silence this warning."
        )

    if (
        torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit
        and num_cudagraphs
        > torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit
    ):
        perf_hint_log.warning(warn_msg())
        return True

    return False


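# The cache above is keyed by the tuple of observed dynamic ints, so every
# new combination records a new graph.  A sketch with an empty cache (one
# graph about to be recorded, well under any warn limit):
#
#   >>> cache: dict[tuple[int, ...], Callable[..., Any]] = {}
#   >>> maybe_warning_due_to_dynamic_shape(cache, (8,))
#   False

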
@dataclasses.dataclass(frozen=True)
class CudagraphCachedInfo:
    """
    Info needed to realign inputs.
    """

    placeholders: Sequence[PlaceholderInfo]
    stack_traces: list[Optional[str]]
    cudagraph_fail_reasons: list[str]