Fix types in graphs.py (#158192)

Added type annotations for torch/cuda/graphs.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158192
Approved by: https://github.com/oulgen
Aaron Orenstein
2025-07-13 11:51:16 -07:00
committed by PyTorch MergeBot
parent 011026205a
commit 250ae2531c
6 changed files with 145 additions and 70 deletions

View File

@ -40,6 +40,7 @@ from torch._C import (
)
from torch._prims_common import DeviceLikeType
from torch.autograd.graph import Node as _Node
from torch.cuda import _POOL_HANDLE
from torch.fx.node import Node as FxNode
from torch.package import PackageExporter
from torch.storage import TypedStorage, UntypedStorage
@ -2289,7 +2290,7 @@ class _CUDAGraph:
def __new__(cls, keep_graph: _bool = ...) -> Self: ...
def capture_begin(
self,
pool: tuple[_int, _int] | None = ...,
pool: _POOL_HANDLE | None = ...,
capture_error_mode: str = "global",
) -> None: ...
def capture_end(self) -> None: ...
@ -2297,7 +2298,7 @@ class _CUDAGraph:
def register_generator_state(self, Generator) -> None: ...
def replay(self) -> None: ...
def reset(self) -> None: ...
def pool(self) -> tuple[_int, _int]: ...
def pool(self) -> _POOL_HANDLE: ...
def enable_debug_mode(self) -> None: ...
def debug_dump(self, debug_path: str) -> None: ...
def raw_cuda_graph(self) -> _int: ...
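
The stub changes above make `pool()` and the `pool=` parameter of `capture_begin` agree on `_POOL_HANDLE`. A rough sketch of the pool-sharing pattern this types, assuming a machine with a CUDA device (the shapes and ops are arbitrary, not from the commit):

import torch

if torch.cuda.is_available():
    x = torch.zeros(4, device="cuda")
    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1):
        y = x + 1
    # g1.pool() is now typed as _POOL_HANDLE, matching the pool= parameter,
    # so a second capture can opt into the first graph's memory pool.
    with torch.cuda.graph(g2, pool=g1.pool()):
        z = y * 2
    g1.replay()
    g2.replay()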

View File

@ -90,6 +90,7 @@ if TYPE_CHECKING:
from torch._guards import CompileId
from torch._inductor.utils import InputType
from torch.cuda import _POOL_HANDLE
from torch.types import _bool
StorageWeakRefPointer = int
@ -817,7 +818,7 @@ class CUDAGraphNode:
id: GraphID,
parent: Optional[CUDAGraphNode],
inputs: list[InputType],
cuda_graphs_pool: tuple[int, int],
cuda_graphs_pool: _POOL_HANDLE,
device_index: int,
stack_traces: Optional[StackTraces],
stream: torch.cuda.Stream,
@ -1228,6 +1229,7 @@ class CUDAGraphNode:
def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType:
"Record the model"
assert self.graph is not None
def static_input_iter() -> Generator[torch.Tensor, None, None]:
for i in self.wrapped_function.static_input_idxs:
@ -1310,13 +1312,11 @@ class CUDAGraphNode:
self.output_storage_alias.append(UnaliasedStorage)
continue
(
torch._check(
o.is_cuda or o.untyped_storage().data_ptr() == 0,
lambda: (
"Expected all cuda outputs in cuda graph recording. Non cuda output "
f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
),
torch._check(
o.is_cuda or o.untyped_storage().data_ptr() == 0,
lambda: (
"Expected all cuda outputs in cuda graph recording. Non cuda output "
f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
),
)
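
For reference, `torch._check` takes a condition and a zero-argument message callable, so the error string is only built when the check fails; the hunk above just drops a stray wrapping parenthesis around that call. A tiny standalone sketch (the tensor and message are illustrative):

import torch

t = torch.ones(2)
# Passes silently; the lambda would only be evaluated if the condition were False.
torch._check(t.numel() == 2, lambda: f"expected 2 elements, got {t.numel()}")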

View File

@ -4,8 +4,8 @@ import inspect
import itertools
import warnings
from collections import OrderedDict
from typing import Any, Optional
from typing_extensions import deprecated
from typing import Any, Callable, Optional, TypeVar
from typing_extensions import Concatenate, deprecated, ParamSpec
import torch
import torch._C as _C
@ -29,6 +29,10 @@ __all__ = [
# This is incremented in FunctionMeta during class definition
AUTOGRAD_FUNCTION_COUNTER = itertools.count()
_T = TypeVar("_T")
_R = TypeVar("_R")
_P = ParamSpec("_P")
# Formerly known as: _ContextMethodMixin
class FunctionCtx:
@ -595,11 +599,13 @@ def _is_setup_context_defined(fn):
return fn != _SingleLevelFunction.setup_context
def once_differentiable(fn):
def once_differentiable(
fn: Callable[Concatenate[_T, _P], _R],
) -> Callable[Concatenate[_T, _P], _R]:
@functools.wraps(fn)
def wrapper(ctx, *args):
def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
with torch.no_grad():
outputs = fn(ctx, *args)
outputs = fn(ctx, *args, **kwargs)
if not torch.is_grad_enabled():
return outputs
@ -620,12 +626,14 @@ def once_differentiable(fn):
return outputs
if not isinstance(outputs, tuple):
outputs = (outputs,)
outputs_ = (outputs,)
else:
outputs_ = outputs
err_fn = _functions.DelayedError(
b"trying to differentiate twice a function that was marked "
b"with @once_differentiable",
len(outputs),
len(outputs_),
)
# Create aliases of each output that has requires_grad=True. We need
@ -637,7 +645,7 @@ def once_differentiable(fn):
var.requires_grad = True
return var
return err_fn(*[fake_requires_grad(v) for v in outputs])
return err_fn(*[fake_requires_grad(v) for v in outputs_]) # type: ignore[return-value]
return wrapper
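
The `Concatenate[_T, _P]` annotation added to `once_differentiable` lets the type checker see that the wrapper preserves the decorated function's signature apart from the leading `ctx` argument. A minimal self-contained sketch of the same idiom (the decorator and function below are illustrative, not part of the commit):

import functools
from typing import Callable, TypeVar
from typing_extensions import Concatenate, ParamSpec

_T = TypeVar("_T")
_R = TypeVar("_R")
_P = ParamSpec("_P")

def passthrough(fn: Callable[Concatenate[_T, _P], _R]) -> Callable[Concatenate[_T, _P], _R]:
    @functools.wraps(fn)
    def wrapper(first: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        # Forwarding the leading argument plus *args/**kwargs keeps the checked signature intact.
        return fn(first, *args, **kwargs)

    return wrapper

@passthrough
def scale(ctx: dict, x: float, factor: float = 2.0) -> float:
    return x * factor

print(scale({}, 3.0))  # 6.0; mypy still sees (dict, float, float) -> float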

View File

@ -18,7 +18,7 @@ import threading
import traceback
import warnings
from functools import lru_cache
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, cast, NewType, Optional, TYPE_CHECKING, Union
import torch
import torch._C
@ -1777,6 +1777,9 @@ def _compile_kernel(
from . import amp, jiterator, nvtx, profiler, sparse, tunable
_POOL_HANDLE = NewType("_POOL_HANDLE", tuple[int, int])
__all__ = [
# Typed storage and tensors
"BFloat16Storage",

View File

@ -1,12 +1,34 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import gc
import typing
from typing import Callable, Optional, overload, TYPE_CHECKING, Union
from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar
import torch
from torch import Tensor
if TYPE_CHECKING:
# importing _POOL_HANDLE at runtime toplevel causes an import cycle
from torch.cuda import _POOL_HANDLE
from .._utils import _dummy_type
__all__ = [
"is_current_stream_capturing",
"graph_pool_handle",
"CUDAGraph",
"graph",
"make_graphed_callables",
]
_R = TypeVar("_R")
_P = ParamSpec("_P")
if not hasattr(torch._C, "_CudaStreamBase"):
# Define dummy base classes
torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
@ -22,7 +44,7 @@ from torch._C import ( # noqa: F401
)
def is_current_stream_capturing():
def is_current_stream_capturing() -> bool:
r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.
If a CUDA context does not exist on the current device, returns False without initializing the context.
@ -31,7 +53,7 @@ def is_current_stream_capturing():
# Python shim helps Sphinx process docstrings more reliably.
def graph_pool_handle():
def graph_pool_handle() -> _POOL_HANDLE:
r"""Return an opaque token representing the id of a graph memory pool.
See :ref:`Graph memory management<graph-memory-management>`.
@ -39,7 +61,7 @@ def graph_pool_handle():
.. warning::
This API is in beta and may change in future releases.
"""
return _graph_pool_handle()
return torch.cuda._POOL_HANDLE(_graph_pool_handle())
# Python shim helps Sphinx process docstrings more reliably.
@ -70,10 +92,12 @@ class CUDAGraph(torch._C._CUDAGraph):
"""
def __new__(cls, keep_graph=False):
def __new__(cls, keep_graph: bool = False) -> Self:
return super().__new__(cls, keep_graph)
def capture_begin(self, pool=None, capture_error_mode="global"):
def capture_begin(
self, pool: Optional[_POOL_HANDLE] = None, capture_error_mode: str = "global"
) -> None:
r"""Begin capturing CUDA work on the current stream.
Typically, you shouldn't call ``capture_begin`` yourself.
@ -92,7 +116,7 @@ class CUDAGraph(torch._C._CUDAGraph):
""" # noqa: B950
super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
def capture_end(self):
def capture_end(self) -> None:
r"""End CUDA graph capture on the current stream.
After ``capture_end``, ``replay`` may be called on this instance.
@ -103,7 +127,7 @@ class CUDAGraph(torch._C._CUDAGraph):
"""
super().capture_end()
def instantiate(self):
def instantiate(self) -> None:
r"""Instantiate the CUDA graph. Will be called by
``capture_end`` if ``keep_graph=False``, or by ``replay`` if
``keep_graph=True`` and ``instantiate`` has not already been
@ -112,15 +136,15 @@ class CUDAGraph(torch._C._CUDAGraph):
"""
super().instantiate()
def replay(self):
def replay(self) -> None:
r"""Replay the CUDA work captured by this graph."""
super().replay()
def reset(self):
def reset(self) -> None:
r"""Delete the graph currently held by this instance."""
super().reset()
def pool(self):
def pool(self) -> _POOL_HANDLE:
r"""Return an opaque token representing the id of this graph's memory pool.
This id can optionally be passed to another graph's ``capture_begin``,
@ -128,11 +152,11 @@ class CUDAGraph(torch._C._CUDAGraph):
"""
return super().pool()
def enable_debug_mode(self):
def enable_debug_mode(self) -> None:
r"""Enable debugging mode for CUDAGraph.debug_dump."""
return super().enable_debug_mode()
def debug_dump(self, debug_path):
def debug_dump(self, debug_path: str) -> None:
r"""
Arguments:
debug_path (required): Path to dump the graph to.
@ -142,7 +166,7 @@ class CUDAGraph(torch._C._CUDAGraph):
"""
return super().debug_dump(debug_path)
def raw_cuda_graph(self):
def raw_cuda_graph(self) -> int:
r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True.
See the following for APIs for how to manipulate this object: `Graph Management <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html>`_ and `cuda-python Graph Management bindings <https://nvidia.github.io/cuda-python/cuda-bindings/latest/module/runtime.html#graph-management>`_
@ -180,13 +204,13 @@ class graph:
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
""" # noqa: B950
default_capture_stream: typing.Optional["torch.cuda.Stream"] = None
default_capture_stream: Optional[torch.cuda.Stream] = None
def __init__(
self,
cuda_graph,
pool=None,
stream=None,
cuda_graph: CUDAGraph,
pool: Optional[_POOL_HANDLE] = None,
stream: Optional[torch.cuda.Stream] = None,
capture_error_mode: str = "global",
):
# Lazy-init of default_capture_stream helps avoid circular-import errors.
@ -195,7 +219,9 @@ class graph:
if self.__class__.default_capture_stream is None:
self.__class__.default_capture_stream = torch.cuda.Stream()
self.pool = () if pool is None else (pool,)
self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = (
() if pool is None else (pool,)
)
self.capture_stream = (
stream if stream is not None else self.__class__.default_capture_stream
)
@ -204,7 +230,7 @@ class graph:
self.cuda_graph = cuda_graph
self.capture_error_mode = capture_error_mode
def __enter__(self):
def __enter__(self) -> None:
# Free as much memory as we can for the graph
torch.cuda.synchronize()
gc.collect()
@ -215,18 +241,47 @@ class graph:
self.stream_ctx.__enter__()
self.cuda_graph.capture_begin(
*self.pool, capture_error_mode=self.capture_error_mode
# type: ignore[misc]
*self.pool,
capture_error_mode=self.capture_error_mode,
)
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, *args: object) -> None:
self.cuda_graph.capture_end()
self.stream_ctx.__exit__(exc_type, exc_value, traceback)
self.stream_ctx.__exit__(*args)
# returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
_ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]]
@overload
def make_graphed_callables(
callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
callables: _ModuleOrCallable,
sample_args: tuple[Tensor, ...],
num_warmup_iters: int = 3,
allow_unused_input: bool = False,
pool: Optional[_POOL_HANDLE] = None,
) -> _ModuleOrCallable: ...
@overload
def make_graphed_callables(
callables: tuple[_ModuleOrCallable, ...],
sample_args: tuple[tuple[Tensor, ...], ...],
num_warmup_iters: int = 3,
allow_unused_input: bool = False,
pool: Optional[_POOL_HANDLE] = None,
) -> tuple[_ModuleOrCallable, ...]: ...
def make_graphed_callables(
callables: Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]],
sample_args: Union[tuple[Tensor, ...], tuple[tuple[Tensor, ...], ...]],
num_warmup_iters: int = 3,
allow_unused_input: bool = False,
pool: Optional[_POOL_HANDLE] = None,
) -> Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]]:
r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
Each graphed callable's forward pass runs its source callable's
@ -300,14 +355,17 @@ def make_graphed_callables(
just_one_callable = False
_sample_args: tuple[tuple[Tensor, ...], ...]
if not isinstance(callables, tuple):
just_one_callable = True
callables = (callables,)
sample_args = (sample_args,)
_sample_args = (typing.cast(tuple[Tensor, ...], sample_args),)
else:
_sample_args = typing.cast(tuple[tuple[Tensor, ...], ...], sample_args)
flatten_sample_args = []
for c, args in zip(callables, sample_args):
for c, args in zip(callables, _sample_args):
if isinstance(c, torch.nn.Module):
assert (
len(c._backward_hooks) == 0
@ -352,7 +410,7 @@ def make_graphed_callables(
torch.cuda.synchronize()
with torch.cuda.stream(torch.cuda.Stream()):
for func, args, static_input_surface in zip(
callables, sample_args, per_callable_static_input_surfaces
callables, _sample_args, per_callable_static_input_surfaces
):
grad_inputs, outputs, outputs_grad = None, None, None
for _ in range(num_warmup_iters):
@ -382,11 +440,11 @@ def make_graphed_callables(
# Capture forward graphs
per_callable_static_outputs = []
per_callable_output_unflatten_spec = []
for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
for func, args, fwd_graph in zip(callables, _sample_args, fwd_graphs):
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args)
func_outputs = func(*args)
flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
flatten_outputs, spec = torch.utils._pytree.tree_flatten(func_outputs)
per_callable_static_outputs.append(tuple(flatten_outputs))
per_callable_output_unflatten_spec.append(spec)
@ -438,19 +496,19 @@ def make_graphed_callables(
# Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
def make_graphed_autograd_function(
fwd_graph,
bwd_graph,
module_params,
len_user_args,
output_unflatten_spec,
static_input_surface,
static_outputs,
static_grad_outputs,
static_grad_inputs,
):
fwd_graph: CUDAGraph,
bwd_graph: CUDAGraph,
module_params: tuple[torch.nn.Parameter, ...],
len_user_args: int,
output_unflatten_spec: torch.utils._pytree.TreeSpec,
static_input_surface: tuple[Tensor, ...],
static_outputs: tuple[Tensor, ...],
static_grad_outputs: tuple[Optional[Tensor], ...],
static_grad_inputs: tuple[Tensor, ...],
) -> Callable[..., object]:
class Graphed(torch.autograd.Function):
@staticmethod
def forward(ctx, *inputs):
def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
# At this stage, only the user args may (potentially) be new tensors.
for i in range(len_user_args):
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
@ -461,7 +519,7 @@ def make_graphed_callables(
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, *grads):
def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
assert len(grads) == len(static_grad_outputs)
for g, grad in zip(static_grad_outputs, grads):
if g is not None:
@ -477,7 +535,7 @@ def make_graphed_callables(
b.detach() if b is not None else b for b in static_grad_inputs
)
def functionalized(*user_args):
def functionalized(*user_args: object) -> object:
# Runs the autograd function with inputs == all inputs to the graph that might require grad
# (explicit user args + module parameters)
# Assumes module params didn't change since capture.
@ -488,7 +546,7 @@ def make_graphed_callables(
return functionalized
# Put together the final graphed callables
ret = []
ret: list[_ModuleOrCallable] = []
for i, func in enumerate(callables):
graphed = make_graphed_autograd_function(
fwd_graphs[i],
@ -504,20 +562,25 @@ def make_graphed_callables(
if isinstance(func, torch.nn.Module):
def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
def new_fwd(*user_args):
def make_graphed_forward(
func: torch.nn.Module,
graph_training_state: bool,
graphed: Callable[_P, _R],
orig_fwd: Callable[_P, _R],
) -> Callable[_P, _R]:
def new_fwd(*user_args: _P.args, **user_kwargs: _P.kwargs) -> _R:
# If the module's training-or-eval state matches what we graphed,
# run the graph, otherwise run the original forward method
if func.training == graph_training_state:
return graphed(*user_args)
return graphed(*user_args, **user_kwargs)
else:
return orig_fwd(*user_args)
return orig_fwd(*user_args, **user_kwargs)
return new_fwd
func.forward = make_graphed_forward(
func, func.training, graphed, func.forward
) # type: ignore[assignment]
)
ret.append(func)
else:
ret.append(graphed)
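
As a usage-level sketch of the APIs annotated in this file, the capture-and-replay pattern from the `torch.cuda.graph` documentation looks roughly like this (guarded because capture needs a real CUDA device; the shapes and warm-up op are arbitrary):

import torch

if torch.cuda.is_available():
    static_input = torch.zeros(8, device="cuda")

    # Warm up on a side stream before capture, as the docs recommend.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_output = static_input * 2
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):  # pool=torch.cuda.graph_pool_handle() would opt into a shared pool
        static_output = static_input * 2

    static_input.copy_(torch.arange(8.0, device="cuda"))
    g.replay()  # reruns the captured kernels in place; static_output now reflects the new input
    print(static_output)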

View File

@ -28,7 +28,7 @@ class _LazyModule:
# NOTE: Add additional used imports here.
if TYPE_CHECKING:
import onnx
import onnx_ir # type: ignore[import-untyped]
import onnx_ir # type: ignore[import-untyped, import-not-found]
import onnxscript
import onnxscript._framework_apis.torch_2_8 as onnxscript_apis