Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Fix types in graphs.py (#158192)

Added type annotations for torch/cuda/graphs.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158192
Approved by: https://github.com/oulgen
Commit 250ae2531c (parent 011026205a), committed by PyTorch MergeBot.
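The change that threads through most of the hunks below is representing the CUDA graph memory-pool id as a dedicated `_POOL_HANDLE` NewType rather than a bare `tuple[int, int]`. A minimal sketch of that NewType pattern, using hypothetical names rather than the actual PyTorch definitions:

```python
from typing import NewType, Optional

# A NewType has no runtime cost: PoolHandle(x) simply returns x.  Static
# checkers, however, treat PoolHandle as distinct from a plain tuple[int, int].
PoolHandle = NewType("PoolHandle", tuple[int, int])


def capture_with_pool(pool: Optional[PoolHandle] = None) -> None:
    # At runtime the handle is still an ordinary (int, int) tuple.
    print("capturing into pool:", pool)


handle = PoolHandle((0, 1))
capture_with_pool(handle)   # accepted by mypy
capture_with_pool((0, 1))   # rejected by mypy: tuple[int, int] is not PoolHandle
```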
@@ -40,6 +40,7 @@ from torch._C import (
 )
 from torch._prims_common import DeviceLikeType
 from torch.autograd.graph import Node as _Node
+from torch.cuda import _POOL_HANDLE
 from torch.fx.node import Node as FxNode
 from torch.package import PackageExporter
 from torch.storage import TypedStorage, UntypedStorage
@@ -2289,7 +2290,7 @@ class _CUDAGraph:
     def __new__(cls, keep_graph: _bool = ...) -> Self: ...
     def capture_begin(
         self,
-        pool: tuple[_int, _int] | None = ...,
+        pool: _POOL_HANDLE | None = ...,
         capture_error_mode: str = "global",
     ) -> None: ...
     def capture_end(self) -> None: ...
@@ -2297,7 +2298,7 @@ class _CUDAGraph:
     def register_generator_state(self, Generator) -> None: ...
     def replay(self) -> None: ...
     def reset(self) -> None: ...
-    def pool(self) -> tuple[_int, _int]: ...
+    def pool(self) -> _POOL_HANDLE: ...
     def enable_debug_mode(self) -> None: ...
     def debug_dump(self, debug_path: str) -> None: ...
     def raw_cuda_graph(self) -> _int: ...
@@ -90,6 +90,7 @@ if TYPE_CHECKING:
 
     from torch._guards import CompileId
     from torch._inductor.utils import InputType
+    from torch.cuda import _POOL_HANDLE
     from torch.types import _bool
 
     StorageWeakRefPointer = int
@@ -817,7 +818,7 @@ class CUDAGraphNode:
         id: GraphID,
         parent: Optional[CUDAGraphNode],
         inputs: list[InputType],
-        cuda_graphs_pool: tuple[int, int],
+        cuda_graphs_pool: _POOL_HANDLE,
         device_index: int,
         stack_traces: Optional[StackTraces],
         stream: torch.cuda.Stream,
@@ -1228,6 +1229,7 @@ class CUDAGraphNode:
 
     def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType:
         "Record the model"
+        assert self.graph is not None
 
         def static_input_iter() -> Generator[torch.Tensor, None, None]:
             for i in self.wrapped_function.static_input_idxs:
@@ -1310,13 +1312,11 @@ class CUDAGraphNode:
                 self.output_storage_alias.append(UnaliasedStorage)
                 continue
 
-            (
-                torch._check(
-                    o.is_cuda or o.untyped_storage().data_ptr() == 0,
-                    lambda: (
-                        "Expected all cuda outputs in cuda graph recording. Non cuda output "
-                        f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
-                    ),
-                ),
+            torch._check(
+                o.is_cuda or o.untyped_storage().data_ptr() == 0,
+                lambda: (
+                    "Expected all cuda outputs in cuda graph recording. Non cuda output "
+                    f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
+                ),
             )
 
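For context on the hunk above, `torch._check` takes a boolean condition and an optional zero-argument callable that lazily builds the error message; the fix drops a redundant tuple wrapping around the call. A small sketch of the call shape (the tensor and message here are illustrative, not from the patch):

```python
import torch

t = torch.ones(2, 3)

# The message callable is only evaluated when the condition is False, so
# building an expensive error string costs nothing on the happy path.
torch._check(
    t.dim() == 2,
    lambda: f"expected a 2D tensor, got {t.dim()} dimensions",
)
```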
@@ -4,8 +4,8 @@ import inspect
 import itertools
 import warnings
 from collections import OrderedDict
-from typing import Any, Optional
-from typing_extensions import deprecated
+from typing import Any, Callable, Optional, TypeVar
+from typing_extensions import Concatenate, deprecated, ParamSpec
 
 import torch
 import torch._C as _C
@@ -29,6 +29,10 @@ __all__ = [
 # This is incremented in FunctionMeta during class definition
 AUTOGRAD_FUNCTION_COUNTER = itertools.count()
 
+_T = TypeVar("_T")
+_R = TypeVar("_R")
+_P = ParamSpec("_P")
+
 
 # Formerly known as: _ContextMethodMixin
 class FunctionCtx:
@@ -595,11 +599,13 @@ def _is_setup_context_defined(fn):
     return fn != _SingleLevelFunction.setup_context
 
 
-def once_differentiable(fn):
+def once_differentiable(
+    fn: Callable[Concatenate[_T, _P], _R],
+) -> Callable[Concatenate[_T, _P], _R]:
     @functools.wraps(fn)
-    def wrapper(ctx, *args):
+    def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
         with torch.no_grad():
-            outputs = fn(ctx, *args)
+            outputs = fn(ctx, *args, **kwargs)
 
         if not torch.is_grad_enabled():
             return outputs
@@ -620,12 +626,14 @@ def once_differentiable(fn):
             return outputs
 
         if not isinstance(outputs, tuple):
-            outputs = (outputs,)
+            outputs_ = (outputs,)
+        else:
+            outputs_ = outputs
 
         err_fn = _functions.DelayedError(
             b"trying to differentiate twice a function that was marked "
             b"with @once_differentiable",
-            len(outputs),
+            len(outputs_),
         )
 
         # Create aliases of each output that has requires_grad=True. We need
@@ -637,7 +645,7 @@ def once_differentiable(fn):
             var.requires_grad = True
             return var
 
-        return err_fn(*[fake_requires_grad(v) for v in outputs])
+        return err_fn(*[fake_requires_grad(v) for v in outputs_])  # type: ignore[return-value]
 
     return wrapper
 
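The `once_differentiable` hunks above use the standard recipe for typing a decorator that preserves the wrapped function's signature apart from a leading context argument: `ParamSpec` captures the remaining parameters and `Concatenate` prepends the context type. A self-contained sketch of the same pattern with generic names (not the PyTorch implementation):

```python
import functools
from typing import Callable, TypeVar

from typing_extensions import Concatenate, ParamSpec

_T = TypeVar("_T")
_R = TypeVar("_R")
_P = ParamSpec("_P")


def log_first_arg(
    fn: Callable[Concatenate[_T, _P], _R],
) -> Callable[Concatenate[_T, _P], _R]:
    # Concatenate[_T, _P] means: one leading positional argument of type _T,
    # followed by whatever parameters the wrapped function declares.
    @functools.wraps(fn)
    def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        print("first argument:", ctx)
        return fn(ctx, *args, **kwargs)

    return wrapper


@log_first_arg
def scale(ctx: str, x: float, factor: float = 2.0) -> float:
    return x * factor


scale("demo", 3.0)  # checkers still see (str, float, factor: float = 2.0) -> float
```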
@@ -18,7 +18,7 @@ import threading
 import traceback
 import warnings
 from functools import lru_cache
-from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, cast, NewType, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch._C
@@ -1777,6 +1777,9 @@ def _compile_kernel(
 from . import amp, jiterator, nvtx, profiler, sparse, tunable
 
 
+_POOL_HANDLE = NewType("_POOL_HANDLE", tuple[int, int])
+
+
 __all__ = [
     # Typed storage and tensors
     "BFloat16Storage",
@@ -1,12 +1,34 @@
-# mypy: allow-untyped-defs
+from __future__ import annotations
+
 import gc
 import typing
+from typing import Callable, Optional, overload, TYPE_CHECKING, Union
+from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar
 
 import torch
+from torch import Tensor
+
+
+if TYPE_CHECKING:
+    # importing _POOL_HANDLE at runtime toplevel causes an import cycle
+    from torch.cuda import _POOL_HANDLE
 
 from .._utils import _dummy_type
 
 
+__all__ = [
+    "is_current_stream_capturing",
+    "graph_pool_handle",
+    "CUDAGraph",
+    "graph",
+    "make_graphed_callables",
+]
+
+
+_R = TypeVar("_R")
+_P = ParamSpec("_P")
+
+
 if not hasattr(torch._C, "_CudaStreamBase"):
     # Define dummy base classes
     torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
@@ -22,7 +44,7 @@ from torch._C import (  # noqa: F401
 )
 
 
-def is_current_stream_capturing():
+def is_current_stream_capturing() -> bool:
     r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.
 
     If a CUDA context does not exist on the current device, returns False without initializing the context.
@@ -31,7 +53,7 @@ def is_current_stream_capturing():
 
 
 # Python shim helps Sphinx process docstrings more reliably.
-def graph_pool_handle():
+def graph_pool_handle() -> _POOL_HANDLE:
     r"""Return an opaque token representing the id of a graph memory pool.
 
     See :ref:`Graph memory management<graph-memory-management>`.
@@ -39,7 +61,7 @@ def graph_pool_handle():
     .. warning::
         This API is in beta and may change in future releases.
     """
-    return _graph_pool_handle()
+    return torch.cuda._POOL_HANDLE(_graph_pool_handle())
 
 
 # Python shim helps Sphinx process docstrings more reliably.
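Returning `_POOL_HANDLE` from `graph_pool_handle()` (and from `CUDAGraph.pool()` below) lets the handle flow into the `pool=` parameters without casts. A usage sketch of the documented pool-sharing pattern, assuming a CUDA-capable device is available:

```python
import torch

# Two captures can share one memory pool; the handle returned here now has the
# same static type that capture_begin / torch.cuda.graph expect for `pool=`.
pool = torch.cuda.graph_pool_handle()

x = torch.zeros(4, device="cuda")
g1 = torch.cuda.CUDAGraph()
g2 = torch.cuda.CUDAGraph()

with torch.cuda.graph(g1, pool=pool):
    y = x + 1

# A graph's own pool() handle can likewise seed another capture.
with torch.cuda.graph(g2, pool=g1.pool()):
    z = y * 2

g1.replay()
g2.replay()
```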
@@ -70,10 +92,12 @@ class CUDAGraph(torch._C._CUDAGraph):
 
     """
 
-    def __new__(cls, keep_graph=False):
+    def __new__(cls, keep_graph: bool = False) -> Self:
         return super().__new__(cls, keep_graph)
 
-    def capture_begin(self, pool=None, capture_error_mode="global"):
+    def capture_begin(
+        self, pool: Optional[_POOL_HANDLE] = None, capture_error_mode: str = "global"
+    ) -> None:
         r"""Begin capturing CUDA work on the current stream.
 
         Typically, you shouldn't call ``capture_begin`` yourself.
@@ -92,7 +116,7 @@ class CUDAGraph(torch._C._CUDAGraph):
         """  # noqa: B950
         super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
 
-    def capture_end(self):
+    def capture_end(self) -> None:
         r"""End CUDA graph capture on the current stream.
 
         After ``capture_end``, ``replay`` may be called on this instance.
@@ -103,7 +127,7 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         super().capture_end()
 
-    def instantiate(self):
+    def instantiate(self) -> None:
         r"""Instantiate the CUDA graph. Will be called by
         ``capture_end`` if ``keep_graph=False``, or by ``replay`` if
         ``keep_graph=True`` and ``instantiate`` has not already been
@@ -112,15 +136,15 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         super().instantiate()
 
-    def replay(self):
+    def replay(self) -> None:
         r"""Replay the CUDA work captured by this graph."""
         super().replay()
 
-    def reset(self):
+    def reset(self) -> None:
         r"""Delete the graph currently held by this instance."""
         super().reset()
 
-    def pool(self):
+    def pool(self) -> _POOL_HANDLE:
         r"""Return an opaque token representing the id of this graph's memory pool.
 
         This id can optionally be passed to another graph's ``capture_begin``,
@@ -128,11 +152,11 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         return super().pool()
 
-    def enable_debug_mode(self):
+    def enable_debug_mode(self) -> None:
         r"""Enable debugging mode for CUDAGraph.debug_dump."""
         return super().enable_debug_mode()
 
-    def debug_dump(self, debug_path):
+    def debug_dump(self, debug_path: str) -> None:
         r"""
         Arguments:
             debug_path (required): Path to dump the graph to.
@@ -142,7 +166,7 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         return super().debug_dump(debug_path)
 
-    def raw_cuda_graph(self):
+    def raw_cuda_graph(self) -> int:
         r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True.
 
         See the following for APIs for how to manipulate this object: `Graph Managmement <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html>`_ and `cuda-python Graph Management bindings <https://nvidia.github.io/cuda-python/cuda-bindings/latest/module/runtime.html#graph-management>`_
@@ -180,13 +204,13 @@ class graph:
         https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
     """  # noqa: B950
 
-    default_capture_stream: typing.Optional["torch.cuda.Stream"] = None
+    default_capture_stream: Optional[torch.cuda.Stream] = None
 
     def __init__(
         self,
-        cuda_graph,
-        pool=None,
-        stream=None,
+        cuda_graph: CUDAGraph,
+        pool: Optional[_POOL_HANDLE] = None,
+        stream: Optional[torch.cuda.Stream] = None,
         capture_error_mode: str = "global",
     ):
         # Lazy-init of default_capture_stream helps avoid circular-import errors.
@@ -195,7 +219,9 @@ class graph:
         if self.__class__.default_capture_stream is None:
             self.__class__.default_capture_stream = torch.cuda.Stream()
 
-        self.pool = () if pool is None else (pool,)
+        self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = (
+            () if pool is None else (pool,)
+        )
         self.capture_stream = (
             stream if stream is not None else self.__class__.default_capture_stream
         )
@@ -204,7 +230,7 @@ class graph:
         self.cuda_graph = cuda_graph
         self.capture_error_mode = capture_error_mode
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         # Free as much memory as we can for the graph
         torch.cuda.synchronize()
         gc.collect()
@@ -215,18 +241,47 @@ class graph:
         self.stream_ctx.__enter__()
 
         self.cuda_graph.capture_begin(
-            *self.pool, capture_error_mode=self.capture_error_mode
+            # type: ignore[misc]
+            *self.pool,
+            capture_error_mode=self.capture_error_mode,
         )
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, *args: object) -> None:
         self.cuda_graph.capture_end()
-        self.stream_ctx.__exit__(exc_type, exc_value, traceback)
+        self.stream_ctx.__exit__(*args)
         # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
 
 
+_ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]]
+
+
+@overload
 def make_graphed_callables(
-    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
-):
+    callables: _ModuleOrCallable,
+    sample_args: tuple[Tensor, ...],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    pool: Optional[_POOL_HANDLE] = None,
+) -> _ModuleOrCallable: ...
+
+
+@overload
+def make_graphed_callables(
+    callables: tuple[_ModuleOrCallable, ...],
+    sample_args: tuple[tuple[Tensor, ...], ...],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    pool: Optional[_POOL_HANDLE] = None,
+) -> tuple[_ModuleOrCallable, ...]: ...
+
+
+def make_graphed_callables(
+    callables: Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]],
+    sample_args: Union[tuple[Tensor, ...], tuple[tuple[Tensor, ...], ...]],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    pool: Optional[_POOL_HANDLE] = None,
+) -> Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]]:
     r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
 
     Each graphed callable's forward pass runs its source callable's
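The pair of `@overload` stubs above lets a checker tie the return shape to the argument shape: a single callable in yields a single callable out, a tuple in yields a tuple out, while one runtime implementation serves both. A generic sketch of the same dispatch-by-overload idea:

```python
from typing import Union, overload


@overload
def double(value: int) -> int: ...
@overload
def double(value: tuple[int, ...]) -> tuple[int, ...]: ...


def double(value: Union[int, tuple[int, ...]]) -> Union[int, tuple[int, ...]]:
    # A single runtime body backs both overload signatures.
    if isinstance(value, tuple):
        return tuple(v * 2 for v in value)
    return value * 2


single: int = double(3)                 # checker infers int
many: tuple[int, ...] = double((1, 2))  # checker infers tuple[int, ...]
```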
@@ -300,14 +355,17 @@ def make_graphed_callables(
 
     just_one_callable = False
 
+    _sample_args: tuple[tuple[Tensor, ...], ...]
     if not isinstance(callables, tuple):
         just_one_callable = True
         callables = (callables,)
-        sample_args = (sample_args,)
+        _sample_args = (typing.cast(tuple[Tensor, ...], sample_args),)
+    else:
+        _sample_args = typing.cast(tuple[tuple[Tensor, ...], ...], sample_args)
 
     flatten_sample_args = []
 
-    for c, args in zip(callables, sample_args):
+    for c, args in zip(callables, _sample_args):
         if isinstance(c, torch.nn.Module):
             assert (
                 len(c._backward_hooks) == 0
@@ -352,7 +410,7 @@ def make_graphed_callables(
     torch.cuda.synchronize()
     with torch.cuda.stream(torch.cuda.Stream()):
         for func, args, static_input_surface in zip(
-            callables, sample_args, per_callable_static_input_surfaces
+            callables, _sample_args, per_callable_static_input_surfaces
         ):
             grad_inputs, outputs, outputs_grad = None, None, None
             for _ in range(num_warmup_iters):
@@ -382,11 +440,11 @@ def make_graphed_callables(
     # Capture forward graphs
     per_callable_static_outputs = []
     per_callable_output_unflatten_spec = []
-    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
+    for func, args, fwd_graph in zip(callables, _sample_args, fwd_graphs):
         with torch.cuda.graph(fwd_graph, pool=mempool):
-            outputs = func(*args)
+            func_outputs = func(*args)
 
-        flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
+        flatten_outputs, spec = torch.utils._pytree.tree_flatten(func_outputs)
         per_callable_static_outputs.append(tuple(flatten_outputs))
         per_callable_output_unflatten_spec.append(spec)
 
@@ -438,19 +496,19 @@ def make_graphed_callables(
     # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
 
     def make_graphed_autograd_function(
-        fwd_graph,
-        bwd_graph,
-        module_params,
-        len_user_args,
-        output_unflatten_spec,
-        static_input_surface,
-        static_outputs,
-        static_grad_outputs,
-        static_grad_inputs,
-    ):
+        fwd_graph: CUDAGraph,
+        bwd_graph: CUDAGraph,
+        module_params: tuple[torch.nn.Parameter, ...],
+        len_user_args: int,
+        output_unflatten_spec: torch.utils._pytree.TreeSpec,
+        static_input_surface: tuple[Tensor, ...],
+        static_outputs: tuple[Tensor, ...],
+        static_grad_outputs: tuple[Optional[Tensor], ...],
+        static_grad_inputs: tuple[Tensor, ...],
+    ) -> Callable[..., object]:
         class Graphed(torch.autograd.Function):
             @staticmethod
-            def forward(ctx, *inputs):
+            def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
                 # At this stage, only the user args may (potentially) be new tensors.
                 for i in range(len_user_args):
                     if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
@@ -461,7 +519,7 @@ def make_graphed_callables(
 
             @staticmethod
             @torch.autograd.function.once_differentiable
-            def backward(ctx, *grads):
+            def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
                 assert len(grads) == len(static_grad_outputs)
                 for g, grad in zip(static_grad_outputs, grads):
                     if g is not None:
@@ -477,7 +535,7 @@ def make_graphed_callables(
                     b.detach() if b is not None else b for b in static_grad_inputs
                 )
 
-        def functionalized(*user_args):
+        def functionalized(*user_args: object) -> object:
             # Runs the autograd function with inputs == all inputs to the graph that might require grad
             # (explicit user args + module parameters)
             # Assumes module params didn't change since capture.
@@ -488,7 +546,7 @@ def make_graphed_callables(
         return functionalized
 
     # Put together the final graphed callables
-    ret = []
+    ret: list[_ModuleOrCallable] = []
     for i, func in enumerate(callables):
         graphed = make_graphed_autograd_function(
             fwd_graphs[i],
@@ -504,20 +562,25 @@ def make_graphed_callables(
 
         if isinstance(func, torch.nn.Module):
 
-            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
-                def new_fwd(*user_args):
+            def make_graphed_forward(
+                func: torch.nn.Module,
+                graph_training_state: bool,
+                graphed: Callable[_P, _R],
+                orig_fwd: Callable[_P, _R],
+            ) -> Callable[_P, _R]:
+                def new_fwd(*user_args: _P.args, **user_kwargs: _P.kwargs) -> _R:
                     # If the module's training-or-eval state matches what we graphed,
                     # run the graph, otherwise run the original forward method
                     if func.training == graph_training_state:
-                        return graphed(*user_args)
+                        return graphed(*user_args, **user_kwargs)
                     else:
-                        return orig_fwd(*user_args)
+                        return orig_fwd(*user_args, **user_kwargs)
 
                 return new_fwd
 
             func.forward = make_graphed_forward(
                 func, func.training, graphed, func.forward
-            )  # type: ignore[assignment]
+            )
             ret.append(func)
         else:
             ret.append(graphed)
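For reference, the overloads and annotations above cover the two documented call patterns of `torch.cuda.make_graphed_callables`. A usage sketch of the single-callable form, assuming a CUDA device and a capture-safe module; real inputs must match the sample shapes:

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
sample = torch.randn(4, 8, device="cuda")

# Single module and a single tuple of sample args in, one graphed module out.
graphed_model = torch.cuda.make_graphed_callables(model, (sample,))

out = graphed_model(torch.randn(4, 8, device="cuda"))
out.sum().backward()  # backward replays the captured backward graph
```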
@@ -28,7 +28,7 @@ class _LazyModule:
 # NOTE: Add additional used imports here.
 if TYPE_CHECKING:
     import onnx
-    import onnx_ir  # type: ignore[import-untyped]
+    import onnx_ir  # type: ignore[import-untyped, import-not-found]
     import onnxscript
     import onnxscript._framework_apis.torch_2_8 as onnxscript_apis
 