Fix types in graphs.py (#158192)

Added type annotations for torch/cuda/graphs.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158192
Approved by: https://github.com/oulgen
Aaron Orenstein
2025-07-13 11:51:16 -07:00
committed by PyTorch MergeBot
parent 011026205a
commit 250ae2531c
6 changed files with 145 additions and 70 deletions
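
For context, a minimal illustrative sketch (assuming a CUDA-capable build) of the pool-handle flow these annotations describe: graph_pool_handle() returns the opaque token that CUDAGraph.capture_begin and the graph context manager accept, and replay() re-runs the captured work.

    import torch

    # Minimal sketch of the API surface typed by this commit; assumes a CUDA device.
    static_x = torch.zeros(8, device="cuda")

    pool = torch.cuda.graph_pool_handle()  # opaque pool token (now typed as _POOL_HANDLE)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):   # capture into the shared memory pool
        static_y = static_x * 2

    static_x.fill_(3.0)
    g.replay()                             # re-runs the captured kernels on the new data
    print(static_y)                        # tensor filled with 6.0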

View File

@@ -40,6 +40,7 @@ from torch._C import (
 )
 from torch._prims_common import DeviceLikeType
 from torch.autograd.graph import Node as _Node
+from torch.cuda import _POOL_HANDLE
 from torch.fx.node import Node as FxNode
 from torch.package import PackageExporter
 from torch.storage import TypedStorage, UntypedStorage
@@ -2289,7 +2290,7 @@ class _CUDAGraph:
     def __new__(cls, keep_graph: _bool = ...) -> Self: ...
     def capture_begin(
         self,
-        pool: tuple[_int, _int] | None = ...,
+        pool: _POOL_HANDLE | None = ...,
         capture_error_mode: str = "global",
     ) -> None: ...
     def capture_end(self) -> None: ...
@@ -2297,7 +2298,7 @@ class _CUDAGraph:
     def register_generator_state(self, Generator) -> None: ...
     def replay(self) -> None: ...
     def reset(self) -> None: ...
-    def pool(self) -> tuple[_int, _int]: ...
+    def pool(self) -> _POOL_HANDLE: ...
     def enable_debug_mode(self) -> None: ...
     def debug_dump(self, debug_path: str) -> None: ...
     def raw_cuda_graph(self) -> _int: ...

View File

@@ -90,6 +90,7 @@ if TYPE_CHECKING:
     from torch._guards import CompileId
     from torch._inductor.utils import InputType
+    from torch.cuda import _POOL_HANDLE
     from torch.types import _bool
 
 StorageWeakRefPointer = int
@@ -817,7 +818,7 @@ class CUDAGraphNode:
         id: GraphID,
         parent: Optional[CUDAGraphNode],
         inputs: list[InputType],
-        cuda_graphs_pool: tuple[int, int],
+        cuda_graphs_pool: _POOL_HANDLE,
         device_index: int,
         stack_traces: Optional[StackTraces],
         stream: torch.cuda.Stream,
@@ -1228,6 +1229,7 @@ class CUDAGraphNode:
     def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType:
         "Record the model"
+        assert self.graph is not None
 
         def static_input_iter() -> Generator[torch.Tensor, None, None]:
             for i in self.wrapped_function.static_input_idxs:
@@ -1310,13 +1312,11 @@ class CUDAGraphNode:
                 self.output_storage_alias.append(UnaliasedStorage)
                 continue
 
-            (
-                torch._check(
-                    o.is_cuda or o.untyped_storage().data_ptr() == 0,
-                    lambda: (
-                        "Expected all cuda outputs in cuda graph recording. Non cuda output "
-                        f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
-                    ),
-                ),
-            )
+            torch._check(
+                o.is_cuda or o.untyped_storage().data_ptr() == 0,
+                lambda: (
+                    "Expected all cuda outputs in cuda graph recording. Non cuda output "
+                    f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
+                ),
+            )

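The reshaped call above uses the two-argument form of torch._check, whose second argument is a callable so the error message is only built on failure. A small standalone illustration; the helper and its arguments are hypothetical, not part of the diff:

    import torch

    def require_cuda_output(o: torch.Tensor, origin: str) -> None:
        # The lambda is evaluated lazily, only if the condition is False.
        torch._check(
            o.is_cuda or o.untyped_storage().data_ptr() == 0,
            lambda: f"Expected a cuda output in cuda graph recording. Non cuda output from {origin}",
        )

    require_cuda_output(torch.ones(2, device="cuda"), "(unknown)")
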
View File

@@ -4,8 +4,8 @@ import inspect
 import itertools
 import warnings
 from collections import OrderedDict
-from typing import Any, Optional
-from typing_extensions import deprecated
+from typing import Any, Callable, Optional, TypeVar
+from typing_extensions import Concatenate, deprecated, ParamSpec
 
 import torch
 import torch._C as _C
@@ -29,6 +29,10 @@ __all__ = [
 # This is incremented in FunctionMeta during class definition
 AUTOGRAD_FUNCTION_COUNTER = itertools.count()
 
+_T = TypeVar("_T")
+_R = TypeVar("_R")
+_P = ParamSpec("_P")
+
 
 # Formerly known as: _ContextMethodMixin
 class FunctionCtx:
@@ -595,11 +599,13 @@ def _is_setup_context_defined(fn):
     return fn != _SingleLevelFunction.setup_context
 
 
-def once_differentiable(fn):
+def once_differentiable(
+    fn: Callable[Concatenate[_T, _P], _R],
+) -> Callable[Concatenate[_T, _P], _R]:
     @functools.wraps(fn)
-    def wrapper(ctx, *args):
+    def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
         with torch.no_grad():
-            outputs = fn(ctx, *args)
+            outputs = fn(ctx, *args, **kwargs)
 
         if not torch.is_grad_enabled():
             return outputs
@@ -620,12 +626,14 @@ def once_differentiable(fn):
             return outputs
 
         if not isinstance(outputs, tuple):
-            outputs = (outputs,)
+            outputs_ = (outputs,)
+        else:
+            outputs_ = outputs
 
         err_fn = _functions.DelayedError(
             b"trying to differentiate twice a function that was marked "
             b"with @once_differentiable",
-            len(outputs),
+            len(outputs_),
         )
 
         # Create aliases of each output that has requires_grad=True. We need
@@ -637,7 +645,7 @@ def once_differentiable(fn):
                 var.requires_grad = True
             return var
 
-        return err_fn(*[fake_requires_grad(v) for v in outputs])
+        return err_fn(*[fake_requires_grad(v) for v in outputs_])  # type: ignore[return-value]
 
     return wrapper

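The new once_differentiable signature follows the usual ParamSpec/Concatenate recipe for decorators that forward a leading ctx argument while preserving the wrapped function's remaining parameters and return type. A self-contained sketch of that pattern; the decorator and function names here are illustrative:

    import functools
    from typing import Callable, TypeVar
    from typing_extensions import Concatenate, ParamSpec

    _T = TypeVar("_T")
    _R = TypeVar("_R")
    _P = ParamSpec("_P")

    def passthrough(fn: Callable[Concatenate[_T, _P], _R]) -> Callable[Concatenate[_T, _P], _R]:
        # The wrapper keeps the leading argument's type (_T), the remaining
        # parameters (_P), and the return type (_R) visible to type checkers.
        @functools.wraps(fn)
        def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
            return fn(ctx, *args, **kwargs)

        return wrapper

    @passthrough
    def scale(ctx: object, x: float, factor: float = 2.0) -> float:
        return x * factor
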
View File

@@ -18,7 +18,7 @@ import threading
 import traceback
 import warnings
 from functools import lru_cache
-from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, cast, NewType, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch._C
@@ -1777,6 +1777,9 @@ def _compile_kernel(
 from . import amp, jiterator, nvtx, profiler, sparse, tunable
 
+_POOL_HANDLE = NewType("_POOL_HANDLE", tuple[int, int])
+
+
 __all__ = [
     # Typed storage and tensors
     "BFloat16Storage",

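_POOL_HANDLE is a typing.NewType over tuple[int, int]: at runtime it is the plain tuple, but type checkers treat it as a distinct, opaque handle that cannot be confused with an arbitrary pair of ints. A generic sketch of the pattern; the names here are illustrative, not part of torch:

    from typing import NewType

    PoolHandle = NewType("PoolHandle", tuple[int, int])

    def use_pool(handle: PoolHandle) -> None:
        # At runtime the handle is just a tuple; the NewType exists only
        # for static checking and adds no overhead.
        device_id, pool_id = handle
        print(device_id, pool_id)

    use_pool(PoolHandle((0, 42)))  # OK: explicitly tagged
    # use_pool((0, 42))            # mypy: expected "PoolHandle", got "tuple[int, int]"
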
View File

@@ -1,12 +1,34 @@
-# mypy: allow-untyped-defs
+from __future__ import annotations
+
 import gc
 import typing
+from typing import Callable, Optional, overload, TYPE_CHECKING, Union
+from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar
 
 import torch
+from torch import Tensor
+
+if TYPE_CHECKING:
+    # importing _POOL_HANDLE at runtime toplevel causes an import cycle
+    from torch.cuda import _POOL_HANDLE
 
 from .._utils import _dummy_type
 
+__all__ = [
+    "is_current_stream_capturing",
+    "graph_pool_handle",
+    "CUDAGraph",
+    "graph",
+    "make_graphed_callables",
+]
+
+_R = TypeVar("_R")
+_P = ParamSpec("_P")
+
+
 if not hasattr(torch._C, "_CudaStreamBase"):
     # Define dummy base classes
     torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
@@ -22,7 +44,7 @@ from torch._C import (  # noqa: F401
 )
 
 
-def is_current_stream_capturing():
+def is_current_stream_capturing() -> bool:
     r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.
 
     If a CUDA context does not exist on the current device, returns False without initializing the context.
@@ -31,7 +53,7 @@ def is_current_stream_capturing():
 # Python shim helps Sphinx process docstrings more reliably.
-def graph_pool_handle():
+def graph_pool_handle() -> _POOL_HANDLE:
     r"""Return an opaque token representing the id of a graph memory pool.
 
     See :ref:`Graph memory management<graph-memory-management>`.
@@ -39,7 +61,7 @@ def graph_pool_handle():
     .. warning::
         This API is in beta and may change in future releases.
     """
-    return _graph_pool_handle()
+    return torch.cuda._POOL_HANDLE(_graph_pool_handle())
 
 
 # Python shim helps Sphinx process docstrings more reliably.
@@ -70,10 +92,12 @@ class CUDAGraph(torch._C._CUDAGraph):
     """
 
-    def __new__(cls, keep_graph=False):
+    def __new__(cls, keep_graph: bool = False) -> Self:
         return super().__new__(cls, keep_graph)
 
-    def capture_begin(self, pool=None, capture_error_mode="global"):
+    def capture_begin(
+        self, pool: Optional[_POOL_HANDLE] = None, capture_error_mode: str = "global"
+    ) -> None:
         r"""Begin capturing CUDA work on the current stream.
 
         Typically, you shouldn't call ``capture_begin`` yourself.
@@ -92,7 +116,7 @@ class CUDAGraph(torch._C._CUDAGraph):
         """  # noqa: B950
         super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
 
-    def capture_end(self):
+    def capture_end(self) -> None:
         r"""End CUDA graph capture on the current stream.
 
         After ``capture_end``, ``replay`` may be called on this instance.
@@ -103,7 +127,7 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         super().capture_end()
 
-    def instantiate(self):
+    def instantiate(self) -> None:
         r"""Instantiate the CUDA graph. Will be called by
         ``capture_end`` if ``keep_graph=False``, or by ``replay`` if
         ``keep_graph=True`` and ``instantiate`` has not already been
@@ -112,15 +136,15 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         super().instantiate()
 
-    def replay(self):
+    def replay(self) -> None:
         r"""Replay the CUDA work captured by this graph."""
         super().replay()
 
-    def reset(self):
+    def reset(self) -> None:
         r"""Delete the graph currently held by this instance."""
         super().reset()
 
-    def pool(self):
+    def pool(self) -> _POOL_HANDLE:
         r"""Return an opaque token representing the id of this graph's memory pool.
 
         This id can optionally be passed to another graph's ``capture_begin``,
@@ -128,11 +152,11 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         return super().pool()
 
-    def enable_debug_mode(self):
+    def enable_debug_mode(self) -> None:
         r"""Enable debugging mode for CUDAGraph.debug_dump."""
         return super().enable_debug_mode()
 
-    def debug_dump(self, debug_path):
+    def debug_dump(self, debug_path: str) -> None:
         r"""
         Arguments:
             debug_path (required): Path to dump the graph to.
@@ -142,7 +166,7 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         return super().debug_dump(debug_path)
 
-    def raw_cuda_graph(self):
+    def raw_cuda_graph(self) -> int:
         r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True.
 
         See the following for APIs for how to manipulate this object: `Graph Managmement <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html>`_ and `cuda-python Graph Management bindings <https://nvidia.github.io/cuda-python/cuda-bindings/latest/module/runtime.html#graph-management>`_
@@ -180,13 +204,13 @@ class graph:
     https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
     """  # noqa: B950
 
-    default_capture_stream: typing.Optional["torch.cuda.Stream"] = None
+    default_capture_stream: Optional[torch.cuda.Stream] = None
 
     def __init__(
         self,
-        cuda_graph,
-        pool=None,
-        stream=None,
+        cuda_graph: CUDAGraph,
+        pool: Optional[_POOL_HANDLE] = None,
+        stream: Optional[torch.cuda.Stream] = None,
         capture_error_mode: str = "global",
     ):
         # Lazy-init of default_capture_stream helps avoid circular-import errors.
@@ -195,7 +219,9 @@ class graph:
         if self.__class__.default_capture_stream is None:
             self.__class__.default_capture_stream = torch.cuda.Stream()
 
-        self.pool = () if pool is None else (pool,)
+        self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = (
+            () if pool is None else (pool,)
+        )
         self.capture_stream = (
             stream if stream is not None else self.__class__.default_capture_stream
         )
@@ -204,7 +230,7 @@ class graph:
         self.cuda_graph = cuda_graph
         self.capture_error_mode = capture_error_mode
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         # Free as much memory as we can for the graph
         torch.cuda.synchronize()
         gc.collect()
@@ -215,18 +241,47 @@ class graph:
         self.stream_ctx.__enter__()
 
         self.cuda_graph.capture_begin(
-            *self.pool, capture_error_mode=self.capture_error_mode
+            # type: ignore[misc]
+            *self.pool,
+            capture_error_mode=self.capture_error_mode,
         )
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, *args: object) -> None:
         self.cuda_graph.capture_end()
-        self.stream_ctx.__exit__(exc_type, exc_value, traceback)
+        self.stream_ctx.__exit__(*args)
         # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
 
 
-def make_graphed_callables(
-    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
-):
+_ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]]
+
+
+@overload
+def make_graphed_callables(
+    callables: _ModuleOrCallable,
+    sample_args: tuple[Tensor, ...],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    pool: Optional[_POOL_HANDLE] = None,
+) -> _ModuleOrCallable: ...
+
+
+@overload
+def make_graphed_callables(
+    callables: tuple[_ModuleOrCallable, ...],
+    sample_args: tuple[tuple[Tensor, ...], ...],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    pool: Optional[_POOL_HANDLE] = None,
+) -> tuple[_ModuleOrCallable, ...]: ...
+
+
+def make_graphed_callables(
+    callables: Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]],
+    sample_args: Union[tuple[Tensor, ...], tuple[tuple[Tensor, ...], ...]],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    pool: Optional[_POOL_HANDLE] = None,
+) -> Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]]:
     r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
 
     Each graphed callable's forward pass runs its source callable's
@@ -300,14 +355,17 @@ def make_graphed_callables(
     just_one_callable = False
 
+    _sample_args: tuple[tuple[Tensor, ...], ...]
     if not isinstance(callables, tuple):
         just_one_callable = True
         callables = (callables,)
-        sample_args = (sample_args,)
+        _sample_args = (typing.cast(tuple[Tensor, ...], sample_args),)
+    else:
+        _sample_args = typing.cast(tuple[tuple[Tensor, ...], ...], sample_args)
 
     flatten_sample_args = []
-    for c, args in zip(callables, sample_args):
+    for c, args in zip(callables, _sample_args):
         if isinstance(c, torch.nn.Module):
             assert (
                 len(c._backward_hooks) == 0
@@ -352,7 +410,7 @@ def make_graphed_callables(
     torch.cuda.synchronize()
     with torch.cuda.stream(torch.cuda.Stream()):
         for func, args, static_input_surface in zip(
-            callables, sample_args, per_callable_static_input_surfaces
+            callables, _sample_args, per_callable_static_input_surfaces
         ):
             grad_inputs, outputs, outputs_grad = None, None, None
             for _ in range(num_warmup_iters):
@@ -382,11 +440,11 @@ def make_graphed_callables(
     # Capture forward graphs
     per_callable_static_outputs = []
     per_callable_output_unflatten_spec = []
-    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
+    for func, args, fwd_graph in zip(callables, _sample_args, fwd_graphs):
         with torch.cuda.graph(fwd_graph, pool=mempool):
-            outputs = func(*args)
+            func_outputs = func(*args)
 
-        flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
+        flatten_outputs, spec = torch.utils._pytree.tree_flatten(func_outputs)
         per_callable_static_outputs.append(tuple(flatten_outputs))
         per_callable_output_unflatten_spec.append(spec)
@@ -438,19 +496,19 @@ def make_graphed_callables(
     # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
 
     def make_graphed_autograd_function(
-        fwd_graph,
-        bwd_graph,
-        module_params,
-        len_user_args,
-        output_unflatten_spec,
-        static_input_surface,
-        static_outputs,
-        static_grad_outputs,
-        static_grad_inputs,
-    ):
+        fwd_graph: CUDAGraph,
+        bwd_graph: CUDAGraph,
+        module_params: tuple[torch.nn.Parameter, ...],
+        len_user_args: int,
+        output_unflatten_spec: torch.utils._pytree.TreeSpec,
+        static_input_surface: tuple[Tensor, ...],
+        static_outputs: tuple[Tensor, ...],
+        static_grad_outputs: tuple[Optional[Tensor], ...],
+        static_grad_inputs: tuple[Tensor, ...],
+    ) -> Callable[..., object]:
         class Graphed(torch.autograd.Function):
             @staticmethod
-            def forward(ctx, *inputs):
+            def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
                 # At this stage, only the user args may (potentially) be new tensors.
                 for i in range(len_user_args):
                     if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
@@ -461,7 +519,7 @@ def make_graphed_callables(
 
             @staticmethod
             @torch.autograd.function.once_differentiable
-            def backward(ctx, *grads):
+            def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
                 assert len(grads) == len(static_grad_outputs)
                 for g, grad in zip(static_grad_outputs, grads):
                     if g is not None:
@@ -477,7 +535,7 @@ def make_graphed_callables(
                     b.detach() if b is not None else b for b in static_grad_inputs
                 )
 
-        def functionalized(*user_args):
+        def functionalized(*user_args: object) -> object:
             # Runs the autograd function with inputs == all inputs to the graph that might require grad
             # (explicit user args + module parameters)
             # Assumes module params didn't change since capture.
@@ -488,7 +546,7 @@ def make_graphed_callables(
         return functionalized
 
     # Put together the final graphed callables
-    ret = []
+    ret: list[_ModuleOrCallable] = []
    for i, func in enumerate(callables):
         graphed = make_graphed_autograd_function(
             fwd_graphs[i],
@@ -504,20 +562,25 @@ def make_graphed_callables(
         if isinstance(func, torch.nn.Module):
 
-            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
-                def new_fwd(*user_args):
+            def make_graphed_forward(
+                func: torch.nn.Module,
+                graph_training_state: bool,
+                graphed: Callable[_P, _R],
+                orig_fwd: Callable[_P, _R],
+            ) -> Callable[_P, _R]:
+                def new_fwd(*user_args: _P.args, **user_kwargs: _P.kwargs) -> _R:
                     # If the module's training-or-eval state matches what we graphed,
                     # run the graph, otherwise run the original forward method
                     if func.training == graph_training_state:
-                        return graphed(*user_args)
+                        return graphed(*user_args, **user_kwargs)
                     else:
-                        return orig_fwd(*user_args)
+                        return orig_fwd(*user_args, **user_kwargs)
 
                 return new_fwd
 
             func.forward = make_graphed_forward(
                 func, func.training, graphed, func.forward
-            )  # type: ignore[assignment]
+            )
             ret.append(func)
         else:
             ret.append(graphed)

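The make_graphed_callables annotations rely on @overload so that a single callable maps to a single graphed callable and a tuple maps to a tuple. A reduced sketch of that overload shape, with a no-op stand-in for the real graphing logic:

    from typing import Callable, Union, overload

    _Fn = Callable[..., object]

    @overload
    def graph_them(callables: _Fn) -> _Fn: ...
    @overload
    def graph_them(callables: tuple[_Fn, ...]) -> tuple[_Fn, ...]: ...

    def graph_them(callables: Union[_Fn, tuple[_Fn, ...]]) -> Union[_Fn, tuple[_Fn, ...]]:
        # Normalize to a tuple, "graph" each item, and return a value whose
        # shape mirrors the input, matching the overloads above.
        just_one = not isinstance(callables, tuple)
        items = (callables,) if just_one else callables
        wrapped = tuple(items)  # stand-in for the real per-callable graphing work
        return wrapped[0] if just_one else wrapped
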
View File

@@ -28,7 +28,7 @@ class _LazyModule:
 # NOTE: Add additional used imports here.
 if TYPE_CHECKING:
     import onnx
-    import onnx_ir  # type: ignore[import-untyped]
+    import onnx_ir  # type: ignore[import-untyped, import-not-found]
     import onnxscript
     import onnxscript._framework_apis.torch_2_8 as onnxscript_apis