mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix types in graphs.py (#158192)
Added type annotations for torch/cuda/graphs.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/158192 Approved by: https://github.com/oulgen
This commit is contained in:
committed by
PyTorch MergeBot
parent
011026205a
commit
250ae2531c
@ -40,6 +40,7 @@ from torch._C import (
|
||||
)
|
||||
from torch._prims_common import DeviceLikeType
|
||||
from torch.autograd.graph import Node as _Node
|
||||
from torch.cuda import _POOL_HANDLE
|
||||
from torch.fx.node import Node as FxNode
|
||||
from torch.package import PackageExporter
|
||||
from torch.storage import TypedStorage, UntypedStorage
|
||||
@ -2289,7 +2290,7 @@ class _CUDAGraph:
|
||||
def __new__(cls, keep_graph: _bool = ...) -> Self: ...
|
||||
def capture_begin(
|
||||
self,
|
||||
pool: tuple[_int, _int] | None = ...,
|
||||
pool: _POOL_HANDLE | None = ...,
|
||||
capture_error_mode: str = "global",
|
||||
) -> None: ...
|
||||
def capture_end(self) -> None: ...
|
||||
@ -2297,7 +2298,7 @@ class _CUDAGraph:
|
||||
def register_generator_state(self, Generator) -> None: ...
|
||||
def replay(self) -> None: ...
|
||||
def reset(self) -> None: ...
|
||||
def pool(self) -> tuple[_int, _int]: ...
|
||||
def pool(self) -> _POOL_HANDLE: ...
|
||||
def enable_debug_mode(self) -> None: ...
|
||||
def debug_dump(self, debug_path: str) -> None: ...
|
||||
def raw_cuda_graph(self) -> _int: ...
|
||||
|
@ -90,6 +90,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from torch._guards import CompileId
|
||||
from torch._inductor.utils import InputType
|
||||
from torch.cuda import _POOL_HANDLE
|
||||
from torch.types import _bool
|
||||
|
||||
StorageWeakRefPointer = int
|
||||
@ -817,7 +818,7 @@ class CUDAGraphNode:
|
||||
id: GraphID,
|
||||
parent: Optional[CUDAGraphNode],
|
||||
inputs: list[InputType],
|
||||
cuda_graphs_pool: tuple[int, int],
|
||||
cuda_graphs_pool: _POOL_HANDLE,
|
||||
device_index: int,
|
||||
stack_traces: Optional[StackTraces],
|
||||
stream: torch.cuda.Stream,
|
||||
@ -1228,6 +1229,7 @@ class CUDAGraphNode:
|
||||
|
||||
def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType:
|
||||
"Record the model"
|
||||
assert self.graph is not None
|
||||
|
||||
def static_input_iter() -> Generator[torch.Tensor, None, None]:
|
||||
for i in self.wrapped_function.static_input_idxs:
|
||||
@ -1310,13 +1312,11 @@ class CUDAGraphNode:
|
||||
self.output_storage_alias.append(UnaliasedStorage)
|
||||
continue
|
||||
|
||||
(
|
||||
torch._check(
|
||||
o.is_cuda or o.untyped_storage().data_ptr() == 0,
|
||||
lambda: (
|
||||
"Expected all cuda outputs in cuda graph recording. Non cuda output "
|
||||
f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
|
||||
),
|
||||
torch._check(
|
||||
o.is_cuda or o.untyped_storage().data_ptr() == 0,
|
||||
lambda: (
|
||||
"Expected all cuda outputs in cuda graph recording. Non cuda output "
|
||||
f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -4,8 +4,8 @@ import inspect
|
||||
import itertools
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import deprecated
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
from typing_extensions import Concatenate, deprecated, ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._C as _C
|
||||
@ -29,6 +29,10 @@ __all__ = [
|
||||
# This is incremented in FunctionMeta during class definition
|
||||
AUTOGRAD_FUNCTION_COUNTER = itertools.count()
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_R = TypeVar("_R")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
# Formerly known as: _ContextMethodMixin
|
||||
class FunctionCtx:
|
||||
@ -595,11 +599,13 @@ def _is_setup_context_defined(fn):
|
||||
return fn != _SingleLevelFunction.setup_context
|
||||
|
||||
|
||||
def once_differentiable(fn):
|
||||
def once_differentiable(
|
||||
fn: Callable[Concatenate[_T, _P], _R],
|
||||
) -> Callable[Concatenate[_T, _P], _R]:
|
||||
@functools.wraps(fn)
|
||||
def wrapper(ctx, *args):
|
||||
def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
with torch.no_grad():
|
||||
outputs = fn(ctx, *args)
|
||||
outputs = fn(ctx, *args, **kwargs)
|
||||
|
||||
if not torch.is_grad_enabled():
|
||||
return outputs
|
||||
@ -620,12 +626,14 @@ def once_differentiable(fn):
|
||||
return outputs
|
||||
|
||||
if not isinstance(outputs, tuple):
|
||||
outputs = (outputs,)
|
||||
outputs_ = (outputs,)
|
||||
else:
|
||||
outputs_ = outputs
|
||||
|
||||
err_fn = _functions.DelayedError(
|
||||
b"trying to differentiate twice a function that was marked "
|
||||
b"with @once_differentiable",
|
||||
len(outputs),
|
||||
len(outputs_),
|
||||
)
|
||||
|
||||
# Create aliases of each output that has requires_grad=True. We need
|
||||
@ -637,7 +645,7 @@ def once_differentiable(fn):
|
||||
var.requires_grad = True
|
||||
return var
|
||||
|
||||
return err_fn(*[fake_requires_grad(v) for v in outputs])
|
||||
return err_fn(*[fake_requires_grad(v) for v in outputs_]) # type: ignore[return-value]
|
||||
|
||||
return wrapper
|
||||
|
||||
|
@ -18,7 +18,7 @@ import threading
|
||||
import traceback
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, cast, NewType, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch._C
|
||||
@ -1777,6 +1777,9 @@ def _compile_kernel(
|
||||
from . import amp, jiterator, nvtx, profiler, sparse, tunable
|
||||
|
||||
|
||||
_POOL_HANDLE = NewType("_POOL_HANDLE", tuple[int, int])
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Typed storage and tensors
|
||||
"BFloat16Storage",
|
||||
|
@ -1,12 +1,34 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import typing
|
||||
from typing import Callable, Optional, overload, TYPE_CHECKING, Union
|
||||
from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# importing _POOL_HANDLE at runtime toplevel causes an import cycle
|
||||
from torch.cuda import _POOL_HANDLE
|
||||
|
||||
from .._utils import _dummy_type
|
||||
|
||||
|
||||
__all__ = [
|
||||
"is_current_stream_capturing",
|
||||
"graph_pool_handle",
|
||||
"CUDAGraph",
|
||||
"graph",
|
||||
"make_graphed_callables",
|
||||
]
|
||||
|
||||
|
||||
_R = TypeVar("_R")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
if not hasattr(torch._C, "_CudaStreamBase"):
|
||||
# Define dummy base classes
|
||||
torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
|
||||
@ -22,7 +44,7 @@ from torch._C import ( # noqa: F401
|
||||
)
|
||||
|
||||
|
||||
def is_current_stream_capturing():
|
||||
def is_current_stream_capturing() -> bool:
|
||||
r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.
|
||||
|
||||
If a CUDA context does not exist on the current device, returns False without initializing the context.
|
||||
@ -31,7 +53,7 @@ def is_current_stream_capturing():
|
||||
|
||||
|
||||
# Python shim helps Sphinx process docstrings more reliably.
|
||||
def graph_pool_handle():
|
||||
def graph_pool_handle() -> _POOL_HANDLE:
|
||||
r"""Return an opaque token representing the id of a graph memory pool.
|
||||
|
||||
See :ref:`Graph memory management<graph-memory-management>`.
|
||||
@ -39,7 +61,7 @@ def graph_pool_handle():
|
||||
.. warning::
|
||||
This API is in beta and may change in future releases.
|
||||
"""
|
||||
return _graph_pool_handle()
|
||||
return torch.cuda._POOL_HANDLE(_graph_pool_handle())
|
||||
|
||||
|
||||
# Python shim helps Sphinx process docstrings more reliably.
|
||||
@ -70,10 +92,12 @@ class CUDAGraph(torch._C._CUDAGraph):
|
||||
|
||||
"""
|
||||
|
||||
def __new__(cls, keep_graph=False):
|
||||
def __new__(cls, keep_graph: bool = False) -> Self:
|
||||
return super().__new__(cls, keep_graph)
|
||||
|
||||
def capture_begin(self, pool=None, capture_error_mode="global"):
|
||||
def capture_begin(
|
||||
self, pool: Optional[_POOL_HANDLE] = None, capture_error_mode: str = "global"
|
||||
) -> None:
|
||||
r"""Begin capturing CUDA work on the current stream.
|
||||
|
||||
Typically, you shouldn't call ``capture_begin`` yourself.
|
||||
@ -92,7 +116,7 @@ class CUDAGraph(torch._C._CUDAGraph):
|
||||
""" # noqa: B950
|
||||
super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
|
||||
|
||||
def capture_end(self):
|
||||
def capture_end(self) -> None:
|
||||
r"""End CUDA graph capture on the current stream.
|
||||
|
||||
After ``capture_end``, ``replay`` may be called on this instance.
|
||||
@ -103,7 +127,7 @@ class CUDAGraph(torch._C._CUDAGraph):
|
||||
"""
|
||||
super().capture_end()
|
||||
|
||||
def instantiate(self):
|
||||
def instantiate(self) -> None:
|
||||
r"""Instantiate the CUDA graph. Will be called by
|
||||
``capture_end`` if ``keep_graph=False``, or by ``replay`` if
|
||||
``keep_graph=True`` and ``instantiate`` has not already been
|
||||
@ -112,15 +136,15 @@ class CUDAGraph(torch._C._CUDAGraph):
|
||||
"""
|
||||
super().instantiate()
|
||||
|
||||
def replay(self):
|
||||
def replay(self) -> None:
|
||||
r"""Replay the CUDA work captured by this graph."""
|
||||
super().replay()
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
r"""Delete the graph currently held by this instance."""
|
||||
super().reset()
|
||||
|
||||
def pool(self):
|
||||
def pool(self) -> _POOL_HANDLE:
|
||||
r"""Return an opaque token representing the id of this graph's memory pool.
|
||||
|
||||
This id can optionally be passed to another graph's ``capture_begin``,
|
||||
@ -128,11 +152,11 @@ class CUDAGraph(torch._C._CUDAGraph):
|
||||
"""
|
||||
return super().pool()
|
||||
|
||||
def enable_debug_mode(self):
|
||||
def enable_debug_mode(self) -> None:
|
||||
r"""Enable debugging mode for CUDAGraph.debug_dump."""
|
||||
return super().enable_debug_mode()
|
||||
|
||||
def debug_dump(self, debug_path):
|
||||
def debug_dump(self, debug_path: str) -> None:
|
||||
r"""
|
||||
Arguments:
|
||||
debug_path (required): Path to dump the graph to.
|
||||
@ -142,7 +166,7 @@ class CUDAGraph(torch._C._CUDAGraph):
|
||||
"""
|
||||
return super().debug_dump(debug_path)
|
||||
|
||||
def raw_cuda_graph(self):
|
||||
def raw_cuda_graph(self) -> int:
|
||||
r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True.
|
||||
|
||||
See the following for APIs for how to manipulate this object: `Graph Managmement <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html>`_ and `cuda-python Graph Management bindings <https://nvidia.github.io/cuda-python/cuda-bindings/latest/module/runtime.html#graph-management>`_
|
||||
@ -180,13 +204,13 @@ class graph:
|
||||
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
|
||||
""" # noqa: B950
|
||||
|
||||
default_capture_stream: typing.Optional["torch.cuda.Stream"] = None
|
||||
default_capture_stream: Optional[torch.cuda.Stream] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cuda_graph,
|
||||
pool=None,
|
||||
stream=None,
|
||||
cuda_graph: CUDAGraph,
|
||||
pool: Optional[_POOL_HANDLE] = None,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
capture_error_mode: str = "global",
|
||||
):
|
||||
# Lazy-init of default_capture_stream helps avoid circular-import errors.
|
||||
@ -195,7 +219,9 @@ class graph:
|
||||
if self.__class__.default_capture_stream is None:
|
||||
self.__class__.default_capture_stream = torch.cuda.Stream()
|
||||
|
||||
self.pool = () if pool is None else (pool,)
|
||||
self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = (
|
||||
() if pool is None else (pool,)
|
||||
)
|
||||
self.capture_stream = (
|
||||
stream if stream is not None else self.__class__.default_capture_stream
|
||||
)
|
||||
@ -204,7 +230,7 @@ class graph:
|
||||
self.cuda_graph = cuda_graph
|
||||
self.capture_error_mode = capture_error_mode
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> None:
|
||||
# Free as much memory as we can for the graph
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
@ -215,18 +241,47 @@ class graph:
|
||||
self.stream_ctx.__enter__()
|
||||
|
||||
self.cuda_graph.capture_begin(
|
||||
*self.pool, capture_error_mode=self.capture_error_mode
|
||||
# type: ignore[misc]
|
||||
*self.pool,
|
||||
capture_error_mode=self.capture_error_mode,
|
||||
)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.cuda_graph.capture_end()
|
||||
self.stream_ctx.__exit__(exc_type, exc_value, traceback)
|
||||
self.stream_ctx.__exit__(*args)
|
||||
# returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
|
||||
|
||||
|
||||
_ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]]
|
||||
|
||||
|
||||
@overload
|
||||
def make_graphed_callables(
|
||||
callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
|
||||
):
|
||||
callables: _ModuleOrCallable,
|
||||
sample_args: tuple[Tensor, ...],
|
||||
num_warmup_iters: int = 3,
|
||||
allow_unused_input: bool = False,
|
||||
pool: Optional[_POOL_HANDLE] = None,
|
||||
) -> _ModuleOrCallable: ...
|
||||
|
||||
|
||||
@overload
|
||||
def make_graphed_callables(
|
||||
callables: tuple[_ModuleOrCallable, ...],
|
||||
sample_args: tuple[tuple[Tensor, ...], ...],
|
||||
num_warmup_iters: int = 3,
|
||||
allow_unused_input: bool = False,
|
||||
pool: Optional[_POOL_HANDLE] = None,
|
||||
) -> tuple[_ModuleOrCallable, ...]: ...
|
||||
|
||||
|
||||
def make_graphed_callables(
|
||||
callables: Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]],
|
||||
sample_args: Union[tuple[Tensor, ...], tuple[tuple[Tensor, ...], ...]],
|
||||
num_warmup_iters: int = 3,
|
||||
allow_unused_input: bool = False,
|
||||
pool: Optional[_POOL_HANDLE] = None,
|
||||
) -> Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]]:
|
||||
r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
|
||||
|
||||
Each graphed callable's forward pass runs its source callable's
|
||||
@ -300,14 +355,17 @@ def make_graphed_callables(
|
||||
|
||||
just_one_callable = False
|
||||
|
||||
_sample_args: tuple[tuple[Tensor, ...], ...]
|
||||
if not isinstance(callables, tuple):
|
||||
just_one_callable = True
|
||||
callables = (callables,)
|
||||
sample_args = (sample_args,)
|
||||
_sample_args = (typing.cast(tuple[Tensor, ...], sample_args),)
|
||||
else:
|
||||
_sample_args = typing.cast(tuple[tuple[Tensor, ...], ...], sample_args)
|
||||
|
||||
flatten_sample_args = []
|
||||
|
||||
for c, args in zip(callables, sample_args):
|
||||
for c, args in zip(callables, _sample_args):
|
||||
if isinstance(c, torch.nn.Module):
|
||||
assert (
|
||||
len(c._backward_hooks) == 0
|
||||
@ -352,7 +410,7 @@ def make_graphed_callables(
|
||||
torch.cuda.synchronize()
|
||||
with torch.cuda.stream(torch.cuda.Stream()):
|
||||
for func, args, static_input_surface in zip(
|
||||
callables, sample_args, per_callable_static_input_surfaces
|
||||
callables, _sample_args, per_callable_static_input_surfaces
|
||||
):
|
||||
grad_inputs, outputs, outputs_grad = None, None, None
|
||||
for _ in range(num_warmup_iters):
|
||||
@ -382,11 +440,11 @@ def make_graphed_callables(
|
||||
# Capture forward graphs
|
||||
per_callable_static_outputs = []
|
||||
per_callable_output_unflatten_spec = []
|
||||
for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
|
||||
for func, args, fwd_graph in zip(callables, _sample_args, fwd_graphs):
|
||||
with torch.cuda.graph(fwd_graph, pool=mempool):
|
||||
outputs = func(*args)
|
||||
func_outputs = func(*args)
|
||||
|
||||
flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
|
||||
flatten_outputs, spec = torch.utils._pytree.tree_flatten(func_outputs)
|
||||
per_callable_static_outputs.append(tuple(flatten_outputs))
|
||||
per_callable_output_unflatten_spec.append(spec)
|
||||
|
||||
@ -438,19 +496,19 @@ def make_graphed_callables(
|
||||
# Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
|
||||
|
||||
def make_graphed_autograd_function(
|
||||
fwd_graph,
|
||||
bwd_graph,
|
||||
module_params,
|
||||
len_user_args,
|
||||
output_unflatten_spec,
|
||||
static_input_surface,
|
||||
static_outputs,
|
||||
static_grad_outputs,
|
||||
static_grad_inputs,
|
||||
):
|
||||
fwd_graph: CUDAGraph,
|
||||
bwd_graph: CUDAGraph,
|
||||
module_params: tuple[torch.nn.Parameter, ...],
|
||||
len_user_args: int,
|
||||
output_unflatten_spec: torch.utils._pytree.TreeSpec,
|
||||
static_input_surface: tuple[Tensor, ...],
|
||||
static_outputs: tuple[Tensor, ...],
|
||||
static_grad_outputs: tuple[Optional[Tensor], ...],
|
||||
static_grad_inputs: tuple[Tensor, ...],
|
||||
) -> Callable[..., object]:
|
||||
class Graphed(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, *inputs):
|
||||
def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
|
||||
# At this stage, only the user args may (potentially) be new tensors.
|
||||
for i in range(len_user_args):
|
||||
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
|
||||
@ -461,7 +519,7 @@ def make_graphed_callables(
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, *grads):
|
||||
def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
|
||||
assert len(grads) == len(static_grad_outputs)
|
||||
for g, grad in zip(static_grad_outputs, grads):
|
||||
if g is not None:
|
||||
@ -477,7 +535,7 @@ def make_graphed_callables(
|
||||
b.detach() if b is not None else b for b in static_grad_inputs
|
||||
)
|
||||
|
||||
def functionalized(*user_args):
|
||||
def functionalized(*user_args: object) -> object:
|
||||
# Runs the autograd function with inputs == all inputs to the graph that might require grad
|
||||
# (explicit user args + module parameters)
|
||||
# Assumes module params didn't change since capture.
|
||||
@ -488,7 +546,7 @@ def make_graphed_callables(
|
||||
return functionalized
|
||||
|
||||
# Put together the final graphed callables
|
||||
ret = []
|
||||
ret: list[_ModuleOrCallable] = []
|
||||
for i, func in enumerate(callables):
|
||||
graphed = make_graphed_autograd_function(
|
||||
fwd_graphs[i],
|
||||
@ -504,20 +562,25 @@ def make_graphed_callables(
|
||||
|
||||
if isinstance(func, torch.nn.Module):
|
||||
|
||||
def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
|
||||
def new_fwd(*user_args):
|
||||
def make_graphed_forward(
|
||||
func: torch.nn.Module,
|
||||
graph_training_state: bool,
|
||||
graphed: Callable[_P, _R],
|
||||
orig_fwd: Callable[_P, _R],
|
||||
) -> Callable[_P, _R]:
|
||||
def new_fwd(*user_args: _P.args, **user_kwargs: _P.kwargs) -> _R:
|
||||
# If the module's training-or-eval state matches what we graphed,
|
||||
# run the graph, otherwise run the original forward method
|
||||
if func.training == graph_training_state:
|
||||
return graphed(*user_args)
|
||||
return graphed(*user_args, **user_kwargs)
|
||||
else:
|
||||
return orig_fwd(*user_args)
|
||||
return orig_fwd(*user_args, **user_kwargs)
|
||||
|
||||
return new_fwd
|
||||
|
||||
func.forward = make_graphed_forward(
|
||||
func, func.training, graphed, func.forward
|
||||
) # type: ignore[assignment]
|
||||
)
|
||||
ret.append(func)
|
||||
else:
|
||||
ret.append(graphed)
|
||||
|
@ -28,7 +28,7 @@ class _LazyModule:
|
||||
# NOTE: Add additional used imports here.
|
||||
if TYPE_CHECKING:
|
||||
import onnx
|
||||
import onnx_ir # type: ignore[import-untyped]
|
||||
import onnx_ir # type: ignore[import-untyped, import-not-found]
|
||||
import onnxscript
|
||||
import onnxscript._framework_apis.torch_2_8 as onnxscript_apis
|
||||
|
||||
|
Reference in New Issue
Block a user