mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
That tells whether or not PyTorch was compiled with Arm Compute Library Pull Request resolved: https://github.com/pytorch/pytorch/pull/165678 Approved by: https://github.com/Skylion007, https://github.com/atalman, https://github.com/albanD ghstack dependencies: #165583, #165584, #165676
2786 lines
97 KiB
Python
2786 lines
97 KiB
Python
# ${generated_comment}
|
|
# mypy: disable-error-code="type-arg"
|
|
# mypy: allow-untyped-defs
|
|
# ruff: noqa: F401
|
|
|
|
from collections.abc import Iterable, Iterator, Sequence
|
|
from enum import Enum, IntEnum
|
|
from pathlib import Path
|
|
from types import EllipsisType
|
|
from typing import (
|
|
Any,
|
|
AnyStr,
|
|
Callable,
|
|
Generic,
|
|
IO,
|
|
Literal,
|
|
NamedTuple,
|
|
overload,
|
|
SupportsIndex,
|
|
TypeVar,
|
|
)
|
|
from typing_extensions import ParamSpec, Protocol, runtime_checkable, Self, TypeAlias
|
|
|
|
import numpy
|
|
|
|
import torch
|
|
from torch import inf, SymInt, Tensor
|
|
from torch._C import (
|
|
_acc,
|
|
_aoti,
|
|
_cpu,
|
|
_dynamo,
|
|
_export,
|
|
_functionalization,
|
|
_functorch,
|
|
_lazy,
|
|
_lazy_ts_backend,
|
|
_nn,
|
|
_onnx,
|
|
_VariableFunctions,
|
|
_verbose,
|
|
)
|
|
from torch._prims_common import DeviceLikeType
|
|
from torch.autograd.graph import Node as _Node
|
|
from torch.cuda import _POOL_HANDLE
|
|
from torch.distributed.tensor._op_schema import OpSchema
|
|
from torch.fx.node import Node as FxNode
|
|
from torch.package import PackageExporter
|
|
from torch.storage import TypedStorage, UntypedStorage
|
|
from torch.types import (
|
|
_bool,
|
|
_bytes,
|
|
_complex,
|
|
_device,
|
|
_dispatchkey,
|
|
_dtype,
|
|
_float,
|
|
_int,
|
|
_layout,
|
|
_qscheme,
|
|
_size,
|
|
_str,
|
|
_symsize,
|
|
Device,
|
|
IntLikeType,
|
|
Number,
|
|
Storage,
|
|
)
|
|
from torch.utils._python_dispatch import TorchDispatchMode
|
|
|
|
# This module is defined in torch/csrc/Module.cpp
|
|
|
|
K = TypeVar("K") # noqa: PYI001
|
|
T = TypeVar("T") # noqa: PYI001
|
|
S = TypeVar("S", bound=torch.Tensor) # noqa: PYI001
|
|
P = ParamSpec("P") # noqa: PYI001
|
|
R = TypeVar("R", covariant=True) # return value (always covariant) # noqa: PYI001
|
|
T_co = TypeVar("T_co", covariant=True) # noqa: PYI001
|
|
|
|
@runtime_checkable
|
|
class _NestedSequence(Protocol[T_co]):
|
|
"""A protocol for representing nested sequences.
|
|
|
|
References::
|
|
`numpy._typing._NestedSequence`
|
|
<https://github.com/numpy/numpy/blob/main/numpy/_typing/_nested_sequence.py>
|
|
"""
|
|
|
|
def __len__(self, /) -> _int: ...
|
|
def __getitem__(self, index: _int, /) -> T_co | _NestedSequence[T_co]: ...
|
|
def __contains__(self, x: object, /) -> _bool: ...
|
|
def __iter__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ...
|
|
def __reversed__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ...
|
|
def count(self, value: Any, /) -> _int: ...
|
|
def index(self, value: Any, /) -> _int: ...
|
|
|
|
# Defined in torch/csrc/Device.cpp
|
|
class device:
|
|
type: str # THPDevice_type
|
|
index: _int # THPDevice_index
|
|
|
|
def __get__(self, instance, owner=None) -> device: ...
|
|
|
|
# THPDevice_pynew
|
|
@overload
|
|
def __init__(self, device: DeviceLikeType) -> None: ...
|
|
@overload
|
|
def __init__(self, type: str, index: _int) -> None: ...
|
|
|
|
# Uncomment if we ever make torch.device a decorator
|
|
# def __call__(self, func: T) -> T: ...
|
|
|
|
def __enter__(self) -> Self: ...
|
|
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
|
|
def __reduce__(self) -> tuple[Any, ...]: ... # THPDevice_reduce
|
|
|
|
# Defined in torch/csrc/Stream.cpp
|
|
class Stream:
|
|
stream_id: _int # Stream id
|
|
device_index: _int
|
|
device_type: _int
|
|
|
|
device: _device # The device of the stream
|
|
|
|
@overload
|
|
def __new__(
|
|
cls,
|
|
device: DeviceLikeType | None = None,
|
|
*,
|
|
priority: _int = 0,
|
|
) -> Self: ...
|
|
@overload
|
|
def __new__(
|
|
cls,
|
|
stream_id: _int,
|
|
device_index: _int,
|
|
device_type: _int,
|
|
*,
|
|
priority: _int = 0,
|
|
) -> Self: ...
|
|
def query(self) -> _bool: ...
|
|
def synchronize(self) -> None: ...
|
|
def wait_event(self, event: Event) -> None: ...
|
|
def wait_stream(self, other: Stream) -> None: ...
|
|
def record_event(self, event: Event | None = None) -> Event: ...
|
|
def __hash__(self) -> _int: ...
|
|
def __eq__(self, other: object) -> _bool: ...
|
|
def __enter__(self) -> Self: ...
|
|
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
|
|
|
|
# Defined in torch/csrc/Event.cpp
|
|
class Event:
|
|
device: _device # The device of the Event
|
|
event_id: _int # The raw event created by device backend
|
|
|
|
def __new__(
|
|
cls,
|
|
device: DeviceLikeType | None = None,
|
|
*,
|
|
enable_timing: _bool = False,
|
|
blocking: _bool = False,
|
|
interprocess: _bool = False,
|
|
) -> Self: ...
|
|
@classmethod
|
|
def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> Event: ...
|
|
def record(self, stream: Stream | None = None) -> None: ...
|
|
def wait(self, stream: Stream | None = None) -> None: ...
|
|
def query(self) -> _bool: ...
|
|
def elapsed_time(self, other: Event) -> _float: ...
|
|
def synchronize(self) -> None: ...
|
|
def ipc_handle(self) -> bytes: ...
|
|
|
|
# Defined in torch/csrc/Size.cpp
|
|
class Size(tuple[_int, ...]):
|
|
# TODO: __reduce__
|
|
|
|
@overload
|
|
def __getitem__(self: Size, key: SupportsIndex, /) -> _int: ...
|
|
@overload
|
|
def __getitem__(self: Size, key: slice, /) -> Size: ...
|
|
# Note: torch.Size does not support adding non-integer tuples.
|
|
def __add__(self, other: tuple[_int, ...], /) -> Size: ... # type: ignore[override]
|
|
def __radd__(self: Size, other: tuple[_int, ...], /) -> Size: ...
|
|
def __mul__(self, other: SupportsIndex, /) -> Size: ...
|
|
def __rmul__(self, other: SupportsIndex, /) -> Size: ...
|
|
def numel(self: Size, /) -> _int: ...
|
|
|
|
# Defined in torch/csrc/Dtype.cpp
|
|
class dtype:
|
|
# TODO: __reduce__
|
|
is_floating_point: _bool
|
|
is_complex: _bool
|
|
is_signed: _bool
|
|
itemsize: _int
|
|
def to_real(self) -> dtype: ...
|
|
def to_complex(self) -> dtype: ...
|
|
|
|
# Defined in torch/csrc/TypeInfo.cpp
|
|
class iinfo:
|
|
bits: _int
|
|
min: _int
|
|
max: _int
|
|
dtype: str
|
|
|
|
def __init__(self, dtype: _dtype) -> None: ...
|
|
|
|
class finfo:
|
|
bits: _int
|
|
min: _float
|
|
max: _float
|
|
eps: _float
|
|
tiny: _float
|
|
smallest_normal: _float
|
|
resolution: _float
|
|
dtype: str
|
|
|
|
@overload
|
|
def __init__(self, dtype: _dtype) -> None: ...
|
|
@overload
|
|
def __init__(self) -> None: ...
|
|
|
|
${dtype_class_hints}
|
|
|
|
# Defined in torch/csrc/Layout.cpp
|
|
class layout: ...
|
|
|
|
# Defined in torch/csrc/utils/disable_torch_function.cpp
|
|
def DisableTorchFunction(): ...
|
|
def DisableTorchFunctionSubclass(): ...
|
|
|
|
# Defined in torch/csrc/utils/tensor_layouts.cpp
|
|
strided: layout = ...
|
|
sparse_coo: layout = ...
|
|
sparse_csr: layout = ...
|
|
sparse_csc: layout = ...
|
|
sparse_bsr: layout = ...
|
|
sparse_bsc: layout = ...
|
|
_mkldnn: layout = ...
|
|
jagged: layout = ...
|
|
|
|
# Defined in torch/csrc/MemoryFormat.cpp
|
|
class memory_format: ...
|
|
|
|
# Defined in torch/csrc/utils/tensor_memoryformats.cpp
|
|
contiguous_format: memory_format = ...
|
|
channels_last: memory_format = ...
|
|
channels_last_3d: memory_format = ...
|
|
preserve_format: memory_format = ...
|
|
|
|
# Defined in torch/csrc/QScheme.cpp
|
|
class qscheme: ...
|
|
|
|
# Defined in torch/csrc/utils/tensor_qschemes.h
|
|
per_tensor_affine: qscheme = ...
|
|
per_channel_affine: qscheme = ...
|
|
per_tensor_symmetric: qscheme = ...
|
|
per_channel_symmetric: qscheme = ...
|
|
per_channel_affine_float_qparams: qscheme = ...
|
|
|
|
# Defined in torch/csrc/autograd/python_function.cpp
|
|
class _FunctionBase:
|
|
saved_tensors: tuple[Tensor]
|
|
_raw_saved_tensors: tuple[Any]
|
|
next_functions: tuple[tuple[Any, _int], ...]
|
|
needs_input_grad: tuple[_bool]
|
|
metadata: dict
|
|
_materialize_non_diff_grads: _bool
|
|
# skip adding type hints for the fields that have wrappers defined
|
|
# in torch/autograd/function.py
|
|
|
|
# Defined in torch/csrc/autograd/python_legacy_variable.cpp
|
|
class _LegacyVariableBase(Tensor): # inherits from Tensor to appease mypy
|
|
def __init__(
|
|
self,
|
|
data: Tensor | None = ...,
|
|
requires_grad: _bool | None = ...,
|
|
volatile: _bool | None = ...,
|
|
_grad_fn: _FunctionBase | None = ...,
|
|
) -> None: ...
|
|
|
|
# Defined in torch/csrc/jit/python/init.cpp
|
|
class IODescriptor: ...
|
|
class JITException(Exception): ...
|
|
|
|
class Future(Generic[T]):
|
|
def __init__(self, devices: list[device]) -> None: ...
|
|
def done(self) -> _bool: ...
|
|
def value(self) -> T: ...
|
|
def wait(self) -> T: ...
|
|
def add_done_callback(self, callback: Callable) -> None: ...
|
|
def then(self, callback: Callable) -> Future[T]: ...
|
|
def set_result(self, result: T) -> None: ...
|
|
def _set_unwrap_func(self, callback: Callable) -> None: ...
|
|
|
|
class _Await:
|
|
def __init__(self) -> None: ...
|
|
def fn(self) -> Callable: ...
|
|
def args(self) -> tuple[Any, ...]: ...
|
|
def is_nowait(self) -> _bool: ...
|
|
|
|
def _jit_set_num_profiled_runs(num: _size) -> _size: ...
|
|
|
|
# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h
|
|
class _MobileOptimizerType: ...
|
|
|
|
CONV_BN_FUSION: _MobileOptimizerType
|
|
INSERT_FOLD_PREPACK_OPS: _MobileOptimizerType
|
|
REMOVE_DROPOUT: _MobileOptimizerType
|
|
FUSE_ADD_RELU: _MobileOptimizerType
|
|
HOIST_CONV_PACKED_PARAMS: _MobileOptimizerType
|
|
VULKAN_AUTOMATIC_GPU_TRANSFER: _MobileOptimizerType
|
|
|
|
def fork(*args: Any, **kwargs: Any) -> Future: ...
|
|
def wait(fut: Future) -> Any: ...
|
|
def _awaitable(*args: Any, **kwargs: Any) -> _Await: ...
|
|
def _awaitable_wait(aw: _Await) -> Any: ...
|
|
def _awaitable_nowait(x: Any) -> _Await: ...
|
|
def _collect_all(futures: list[Future]) -> Future: ...
|
|
def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ...
|
|
def unify_type_list(types: list[JitType]) -> JitType: ...
|
|
def _freeze_module(
|
|
module: ScriptModule,
|
|
preserved_attrs: list[str] = ...,
|
|
freeze_interfaces: _bool = True,
|
|
preserveParameters: _bool = True,
|
|
) -> ScriptModule: ...
|
|
def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ...
|
|
def _jit_pass_optimize_for_inference(
|
|
module: torch.jit.ScriptModule,
|
|
other_methods: list[str] = ...,
|
|
) -> None: ...
|
|
def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
|
|
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
|
|
def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
|
|
def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ...
|
|
def _jit_pass_concat_frozen_linear(graph: Graph): ...
|
|
def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ...
|
|
def _jit_pass_transpose_frozen_linear(graph: Graph): ...
|
|
def _jit_pass_remove_dropout(module: torch.jit.ScriptModule): ...
|
|
def _is_tracing() -> _bool: ...
|
|
def _jit_init() -> _bool: ...
|
|
def _jit_flatten(arg: Any) -> tuple[list[Tensor], IODescriptor]: ...
|
|
def _jit_unflatten(vars: list[Tensor], desc: IODescriptor) -> Any: ...
|
|
def _jit_get_operation(op_name: str) -> tuple[Callable, list[str]]: ...
|
|
def _get_operation_overload(
|
|
op_name: str,
|
|
op_overload_name: str,
|
|
) -> tuple[Callable, Callable, list[Any]]: ...
|
|
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
|
|
def _jit_pass_optimize_for_mobile(
|
|
module: torch.jit.ScriptModule,
|
|
optimization_blocklist: set[_MobileOptimizerType],
|
|
preserved_methods: list[AnyStr],
|
|
) -> torch.jit.ScriptModule: ...
|
|
def _clone_module_with_class(
|
|
module: torch.jit.ScriptModule,
|
|
ignored_methods: list[AnyStr],
|
|
ignored_attributes: list[AnyStr],
|
|
) -> torch.jit.ScriptModule: ...
|
|
def _jit_pass_vulkan_optimize_for_mobile(
|
|
module: torch.jit.ScriptModule,
|
|
optimization_blocklist: set[_MobileOptimizerType],
|
|
preserved_methods: list[AnyStr],
|
|
) -> torch.jit.ScriptModule: ...
|
|
def _jit_pass_metal_optimize_for_mobile(
|
|
module: torch.jit.ScriptModule,
|
|
preserved_methods: list[AnyStr],
|
|
) -> torch.jit.ScriptModule: ...
|
|
def _jit_pass_inline(Graph) -> None: ...
|
|
def _jit_pass_constant_propagation(Graph) -> None: ...
|
|
def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ...
|
|
def _jit_register_decomposition_for_schema(schema: FunctionSchema, Graph) -> None: ...
|
|
def _jit_erase_non_input_shape_information(Graph) -> None: ...
|
|
def _jit_get_schemas_for_operator(name: str) -> list[FunctionSchema]: ...
|
|
def _jit_get_all_schemas() -> list[FunctionSchema]: ...
|
|
def _jit_check_alias_annotation(
|
|
g: Graph,
|
|
args: tuple[Any, ...],
|
|
unqualified_op_name: str,
|
|
): ...
|
|
def _jit_can_fuse_on_cpu() -> _bool: ...
|
|
def _jit_can_fuse_on_gpu() -> _bool: ...
|
|
def _jit_can_fuse_on_cpu_legacy() -> _bool: ...
|
|
def _debug_get_fusion_group_inlining() -> _bool: ...
|
|
def _debug_set_fusion_group_inlining(enable: _bool): ...
|
|
def _jit_texpr_fuser_enabled() -> _bool: ...
|
|
def _jit_nvfuser_enabled() -> _bool: ...
|
|
def _jit_llga_enabled() -> _bool: ...
|
|
def _jit_set_llga_enabled(enable: _bool): ...
|
|
def _llvm_enabled() -> _bool: ...
|
|
def _jit_override_can_fuse_on_cpu(override: _bool): ...
|
|
def _jit_override_can_fuse_on_gpu(override: _bool): ...
|
|
def _jit_override_can_fuse_on_cpu_legacy(override: _bool): ...
|
|
def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
|
|
def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
|
|
def _jit_set_texpr_fuser_enabled(enable: _bool): ...
|
|
def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ...
|
|
def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ...
|
|
def _jit_cat_wo_conditionals(optimize_cat: _bool): ...
|
|
def _jit_opt_conditionals(opt_conds: _bool): ...
|
|
def _jit_pass_canonicalize(graph: Graph, keep_unique_names: _bool = True): ...
|
|
def _jit_pass_erase_shape_information(graph: Graph): ...
|
|
def _jit_pass_fold_convbn(module: torch.jit.ScriptModule): ...
|
|
def _jit_pass_insert_observers(
|
|
module: torch.jit.ScriptModule,
|
|
method_name: str,
|
|
qconfig_dict: dict[str, Any],
|
|
inplace: _bool,
|
|
quant_type: _int,
|
|
): ...
|
|
def _jit_pass_insert_quant_dequant(
|
|
module: torch.jit.ScriptModule,
|
|
method_name: str,
|
|
inplace: _bool,
|
|
debug: _bool,
|
|
quant_type: _int,
|
|
): ...
|
|
def _jit_pass_insert_quant_dequant_for_ondevice_ptq(
|
|
module: torch.jit.ScriptModule,
|
|
method_name: str,
|
|
inplace: _bool,
|
|
debug: _bool,
|
|
quant_type: _int,
|
|
): ...
|
|
def _jit_pass_quant_finalize(
|
|
module: torch.jit.ScriptModule,
|
|
quant_type: _int,
|
|
preserved_attrs: Sequence[str],
|
|
): ...
|
|
def _jit_pass_quant_finalize_for_ondevice_ptq(
|
|
module: torch.jit.ScriptModule,
|
|
quant_type: _int,
|
|
method_name: str,
|
|
): ...
|
|
def _jit_pass_insert_observer_method_for_ondevice_ptq(
|
|
module: torch.jit.ScriptModule,
|
|
method_name: str,
|
|
qconfig_dict: dict[str, Any],
|
|
inplace: _bool,
|
|
quant_type: _int,
|
|
): ...
|
|
def _jit_set_profiling_executor(profiling_flag: _bool) -> _bool: ...
|
|
def _jit_set_profiling_mode(profiling_flag: _bool) -> _bool: ...
|
|
def _jit_set_fusion_strategy(
|
|
strategy: list[tuple[str, _int]],
|
|
) -> list[tuple[str, _int]]: ...
|
|
def _jit_try_infer_type(obj: Any) -> InferredType: ...
|
|
def _jit_get_trigger_value(trigger_name: str) -> _int: ...
|
|
|
|
# Defined in torch/csrc/jit/python/script_init.cpp
|
|
ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]]
|
|
|
|
# Defined in torch/csrc/jit/python/script_init.cpp
|
|
# and torch/csrc/jit/python/init.cpp
|
|
def _maybe_call_torch_function_for_op_packet(
|
|
op_overload_packet: Any,
|
|
*args: Any,
|
|
**kwargs: Any,
|
|
) -> Any: ...
|
|
def _check_schema_allow_fake_script_object(
|
|
schema: FunctionSchema,
|
|
*args: Any,
|
|
**kwargs: Any,
|
|
) -> _bool: ...
|
|
def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ...
|
|
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
|
|
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
|
|
def _jit_assert_is_instance(obj: Any, type: JitType): ...
|
|
def _jit_clear_class_registry() -> None: ...
|
|
def _jit_set_emit_hooks(
|
|
ModuleHook: Callable | None,
|
|
FunctionHook: Callable | None,
|
|
) -> None: ...
|
|
def _jit_get_emit_hooks() -> tuple[Callable, Callable]: ...
|
|
def _load_for_lite_interpreter(
|
|
filename: str | Path,
|
|
map_location: DeviceLikeType | None,
|
|
): ...
|
|
def _load_for_lite_interpreter_from_buffer(
|
|
buffer: IO[bytes],
|
|
map_location: DeviceLikeType | None,
|
|
): ...
|
|
def _export_operator_list(module: LiteScriptModule): ...
|
|
def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
|
|
def _get_model_bytecode_version(filename: str | Path) -> _int: ...
|
|
def _get_model_bytecode_version_from_buffer(buffer: IO[bytes]) -> _int: ...
|
|
def _backport_for_mobile(
|
|
filename_input: str | Path,
|
|
filename_output: str | Path,
|
|
to_version: _int,
|
|
) -> None: ...
|
|
def _backport_for_mobile_from_buffer(
|
|
buffer: IO[bytes],
|
|
filename_output: str | Path,
|
|
to_version: _int,
|
|
) -> None: ...
|
|
def _backport_for_mobile_to_buffer(
|
|
filename_input: str | Path,
|
|
to_version: _int,
|
|
) -> bytes: ...
|
|
def _backport_for_mobile_from_buffer_to_buffer(
|
|
buffer: IO[bytes],
|
|
to_version: _int,
|
|
) -> bytes: ...
|
|
def _get_model_ops_and_info(filename: str | Path): ...
|
|
def _get_model_ops_and_info_from_buffer(buffer: IO[bytes]): ...
|
|
def _get_mobile_model_contained_types(filename: str | Path): ...
|
|
def _get_mobile_model_contained_types_from_buffer(buffer: IO[bytes]): ...
|
|
def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ...
|
|
def _get_graph_executor_optimize(optimize: _bool | None = None) -> _bool: ...
|
|
def _set_graph_executor_optimize(optimize: _bool): ...
|
|
def _export_opnames(module: ScriptModule) -> list[str]: ...
|
|
def _create_function_from_trace(
|
|
qualname: str,
|
|
func: Callable[..., Any],
|
|
input_tuple: tuple[Any, ...],
|
|
var_lookup_fn: Callable[[Tensor], str],
|
|
strict: _bool,
|
|
force_outplace: _bool,
|
|
argument_names: list[str],
|
|
) -> tuple[Graph, Stack]: ...
|
|
def _create_function_from_trace_with_dict(
|
|
qualname: str,
|
|
func: Callable[..., Any],
|
|
input_dict: dict[str, Any],
|
|
var_lookup_fn: Callable[[Tensor], str],
|
|
strict: _bool,
|
|
force_outplace: _bool,
|
|
argument_names: list[str],
|
|
) -> tuple[Graph, Stack]: ...
|
|
def _jit_is_script_object(obj: Any) -> _bool: ...
|
|
def _last_executed_optimized_graph() -> Graph: ...
|
|
def parse_type_comment(comment: str) -> Decl: ...
|
|
def _get_upgraders_map_size() -> _int: ...
|
|
def _get_upgraders_entry_map() -> dict[str, str]: ...
|
|
def _dump_upgraders_map() -> dict[str, str]: ...
|
|
def _test_only_populate_upgraders(content: dict[str, str]) -> None: ...
|
|
def _test_only_remove_upgraders(content: dict[str, str]) -> None: ...
|
|
def merge_type_from_type_comment(
|
|
decl: Decl,
|
|
type_annotation_decl: Decl,
|
|
is_method: _bool,
|
|
) -> Decl: ...
|
|
def parse_ir(input: str, parse_tensor_constants: _bool = False) -> Graph: ...
|
|
def parse_schema(schema: str) -> FunctionSchema: ...
|
|
def get_device(input: Tensor) -> _int: ...
|
|
def _resolve_type_from_object(
|
|
obj: Any,
|
|
range: SourceRange,
|
|
rcb: ResolutionCallback,
|
|
) -> JitType: ...
|
|
def _create_module_with_type(ty: JitType) -> ScriptModule: ...
|
|
def _create_object_with_type(ty: ClassType) -> ScriptObject: ...
|
|
def _run_emit_module_hook(m: ScriptModule): ...
|
|
def _replace_overloaded_method_decl(
|
|
overload_decl: Decl,
|
|
implementation_def: Def,
|
|
new_name: str,
|
|
) -> Def: ...
|
|
def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
|
|
def _jit_pass_onnx_set_dynamic_input_shape(
|
|
graph: Graph,
|
|
dynamic_axes: dict[str, dict[_int, str]],
|
|
input_names: list[str],
|
|
) -> None: ...
|
|
def _jit_pass_onnx_graph_shape_type_inference(
|
|
graph: Graph,
|
|
params_dict: dict[str, IValue],
|
|
opset_version: _int,
|
|
) -> None: ...
|
|
def _jit_pass_onnx_assign_output_shape(
|
|
graph: Graph,
|
|
tensors: list[Tensor],
|
|
desc: IODescriptor,
|
|
onnx_shape_inference: _bool,
|
|
is_script: _bool,
|
|
opset_version: _int,
|
|
) -> None: ...
|
|
def _jit_pass_onnx_remove_inplace_ops_for_onnx(
|
|
graph: Graph,
|
|
module: ScriptModule | None = None,
|
|
) -> None: ...
|
|
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
|
|
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
|
|
def _jit_pass_peephole(
|
|
graph: Graph,
|
|
disable_shape_peepholes: _bool = False,
|
|
) -> None: ...
|
|
def _jit_pass_onnx_autograd_function_process(graph: Graph) -> None: ...
|
|
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
|
|
def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
|
|
def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
|
|
def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
|
|
def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
|
|
def _jit_pass_onnx_unpack_quantized_weights(
|
|
graph: Graph,
|
|
paramsDict: dict[str, IValue],
|
|
) -> dict[str, IValue]: ...
|
|
def _jit_pass_onnx_quantization_insert_permutes(
|
|
graph: Graph,
|
|
paramsDict: dict[str, IValue],
|
|
) -> dict[str, IValue]: ...
|
|
def _jit_pass_custom_pattern_based_rewrite_graph(
|
|
pattern: str,
|
|
fused_node_name: str,
|
|
graph: Graph,
|
|
) -> None: ...
|
|
def _jit_onnx_list_model_parameters(
|
|
module: ScriptModule,
|
|
) -> tuple[ScriptModule, list[IValue]]: ...
|
|
def _jit_pass_erase_number_types(graph: Graph) -> None: ...
|
|
def _jit_pass_onnx_lint(graph: Graph) -> None: ...
|
|
def _jit_pass_onnx(
|
|
graph: Graph,
|
|
_jit_pass_onnx: _onnx.OperatorExportTypes,
|
|
) -> Graph: ...
|
|
def _jit_pass_onnx_scalar_type_analysis(
|
|
graph: Graph,
|
|
lowprecision_cast: _bool,
|
|
opset_version: _int,
|
|
) -> None: ...
|
|
def _jit_pass_onnx_peephole(
|
|
graph: Graph,
|
|
opset_version: _int,
|
|
fixed_batch_size: _bool,
|
|
) -> None: ...
|
|
def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
|
|
def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
|
|
def _jit_pass_onnx_function_extraction(
|
|
graph: Graph,
|
|
module_names: set[str],
|
|
param_names: list[str],
|
|
) -> dict[Node, dict[str, str]]: ...
|
|
def _jit_pass_onnx_clear_scope_records() -> None: ...
|
|
def _jit_pass_onnx_track_scope_attributes(
|
|
graph: Graph,
|
|
onnx_attrs: dict[str, Any],
|
|
) -> None: ...
|
|
def _jit_is_onnx_log_enabled() -> _bool: ...
|
|
def _jit_set_onnx_log_enabled(enabled: _bool) -> None: ...
|
|
def _jit_set_onnx_log_output_stream(stream_name: str) -> None: ...
|
|
def _jit_onnx_log(*args: Any) -> None: ...
|
|
def _jit_pass_lower_graph(graph: Graph, m: Module) -> tuple[Graph, list[IValue]]: ...
|
|
def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
|
|
def _jit_pass_onnx_deduplicate_initializers(
|
|
graph: Graph,
|
|
params_dict: dict[str, IValue],
|
|
is_train: _bool,
|
|
) -> dict[str, IValue]: ...
|
|
def _jit_pass_onnx_eval_peephole(
|
|
graph: Graph,
|
|
paramsDict: dict[str, IValue],
|
|
) -> dict[str, IValue]: ...
|
|
def _jit_pass_onnx_constant_fold(
|
|
graph: Graph,
|
|
paramsDict: dict[str, IValue],
|
|
opset_version: _int,
|
|
) -> dict[str, IValue]: ...
|
|
def _jit_pass_onnx_eliminate_unused_items(
|
|
graph: Graph,
|
|
paramsDict: dict[str, IValue],
|
|
) -> dict[str, IValue]: ...
|
|
def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
|
|
def _jit_pass_filter_non_tensor_arguments(
|
|
params: dict[str, IValue],
|
|
) -> dict[str, Tensor]: ...
|
|
def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
|
|
def _jit_pass_onnx_node_shape_type_inference(
|
|
n: Node,
|
|
paramsDict: dict[str, IValue],
|
|
opset_version: _int,
|
|
) -> None: ...
|
|
def _jit_onnx_convert_pattern_from_subblock(
|
|
block: Block,
|
|
n: Node,
|
|
env: dict[Value, Value],
|
|
values_in_env: set[Value],
|
|
) -> list[Value]: ...
|
|
def _jit_pass_onnx_block(
|
|
old_block: Block,
|
|
new_block: Block,
|
|
operator_export_type: _onnx.OperatorExportTypes,
|
|
env: dict[Value, Value],
|
|
values_in_env: set[Value],
|
|
is_sub_block: _bool,
|
|
) -> dict[Value, Value]: ...
|
|
def _jit_pass_onnx_assign_scoped_names_for_node_and_value(graph: Graph) -> None: ...
|
|
def _jit_pass_fixup_onnx_controlflow_node(
|
|
n: Node,
|
|
opset_version: _int,
|
|
) -> list[Value]: ...
|
|
def _jit_onnx_create_full_scope_name(class_name: str, variable_name: str) -> str: ...
|
|
def _compile_graph_to_code_table(name: str, graph: Graph) -> IValue: ...
|
|
def _generate_upgraders_graph() -> dict[str, Graph]: ...
|
|
def _calculate_package_version_based_on_upgraders(val: _bool): ...
|
|
def _get_version_calculator_flag() -> _bool: ...
|
|
def _jit_script_interface_compile(
|
|
name: str,
|
|
class_def: ClassDef,
|
|
rcb: ResolutionCallback,
|
|
is_module: _bool,
|
|
): ...
|
|
def _jit_script_compile_overload(
|
|
qualname: str,
|
|
overload_decl: Decl,
|
|
implementation_def: Def,
|
|
rcb: ResolutionCallback,
|
|
implementation_defaults: dict[str, Any],
|
|
signature: Any,
|
|
): ...
|
|
def _jit_script_compile(
|
|
qual_name: str,
|
|
definition: Def,
|
|
rcb: ResolutionCallback,
|
|
defaults: dict[str, Any],
|
|
): ...
|
|
def _jit_script_class_compile(
|
|
qual_name: str,
|
|
definition: ClassDef,
|
|
defaults: dict[str, dict[str, Any]],
|
|
rcb: ResolutionCallback,
|
|
): ...
|
|
def _parse_source_def(src: str) -> Def: ...
|
|
def import_ir_module(
|
|
cu: CompilationUnit,
|
|
filename: str | Path,
|
|
map_location: DeviceLikeType | None,
|
|
extra_files: dict[str, Any],
|
|
) -> ScriptModule: ...
|
|
def import_ir_module_from_buffer(
|
|
cu: CompilationUnit,
|
|
buffer: IO[bytes],
|
|
map_location: DeviceLikeType | None,
|
|
extra_files: dict[str, Any],
|
|
) -> ScriptModule: ...
|
|
def _import_ir_module_from_package(
|
|
cu: CompilationUnit,
|
|
reader: PyTorchFileReader,
|
|
storage_context: DeserializationStorageContext,
|
|
map_location: DeviceLikeType | None,
|
|
ts_id: str,
|
|
) -> ScriptModule: ...
|
|
def _assign_output_shapes(graph: Graph, inputs: list[Tensor]) -> Graph: ...
|
|
def _check_onnx_proto(proto: str) -> None: ...
|
|
def _propagate_and_assign_input_shapes(
|
|
graph: Graph,
|
|
inputs: tuple[Tensor, ...],
|
|
param_count_list: list[_int],
|
|
with_grad: _bool,
|
|
propagate: _bool,
|
|
) -> Graph: ...
|
|
|
|
# Defined in torch/csrc/jit/runtime/graph_executor.h
|
|
class GraphExecutorState: ...
|
|
|
|
# Defined in torch/torch/csrc/jit/ir/alias_analysis.h
|
|
class AliasDb: ...
|
|
|
|
class _InsertPoint:
|
|
def __enter__(self) -> None: ...
|
|
def __exit__(self, *exc_info: object) -> None: ...
|
|
|
|
# Defined in torch/csrc/jit/ir/ir.h
|
|
class Use:
|
|
@property
|
|
def user(self) -> Node: ...
|
|
@property
|
|
def offset(self) -> _int: ...
|
|
def isAfter(self, other: Use) -> _bool: ...
|
|
|
|
# Defined in torch/csrc/jit/ir/ir.h
|
|
class Value:
|
|
def type(self) -> JitType: ...
|
|
def setType(self, t: JitType) -> Value: ...
|
|
def setTypeAs(self, other: Value) -> Value: ...
|
|
def inferTypeFrom(self, t: Tensor) -> None: ...
|
|
def debugName(self) -> str: ...
|
|
def setDebugName(self, name: str) -> None: ...
|
|
def unique(self) -> _int: ...
|
|
def offset(self) -> _int: ...
|
|
def node(self) -> Node: ...
|
|
def uses(self) -> list[Use]: ...
|
|
def replaceAllUsesWith(self, val: Value) -> None: ...
|
|
def replaceAllUsesAfterNodeWith(self, node: Node, val: Value) -> None: ...
|
|
def requires_grad(self) -> _bool: ...
|
|
def requiresGrad(self) -> _bool: ...
|
|
def copyMetadata(self, other: Value) -> Value: ...
|
|
def isCompleteTensor(self) -> _bool: ...
|
|
def toIValue(self) -> IValue: ...
|
|
|
|
# Defined in torch/csrc/jit/ir/ir.h
|
|
class Block:
|
|
def inputs(self) -> Iterator[Value]: ...
|
|
def outputs(self) -> Iterator[Value]: ...
|
|
def nodes(self) -> Iterator[Node]: ...
|
|
def paramNode(self) -> Node: ...
|
|
def returnNode(self) -> Node: ...
|
|
def owningNode(self) -> Node: ...
|
|
def registerOutput(self, n: Value) -> _int: ...
|
|
def addNode(self, name: str, inputs: Sequence[Value]) -> Node: ...
|
|
|
|
# Defined in torch/csrc/jit/ir/ir.h
|
|
class Node:
|
|
def __getitem__(self, key: str) -> Any: ...
|
|
def schema(self) -> str: ...
|
|
def input(self) -> Value: ...
|
|
def inputs(self) -> Iterator[Value]: ...
|
|
def inputsAt(self, idx: _int) -> Value: ...
|
|
def inputsSize(self) -> _int: ...
|
|
def output(self) -> Value: ...
|
|
def outputs(self) -> Iterator[Value]: ...
|
|
def outputsAt(self, idx: _int) -> Value: ...
|
|
def outputsSize(self) -> _int: ...
|
|
def hasMultipleOutputs(self) -> _bool: ...
|
|
def blocks(self) -> list[Block]: ...
|
|
def addBlock(self) -> Block: ...
|
|
def mustBeNone(self) -> _bool: ...
|
|
def matches(self, pattern: str) -> _bool: ...
|
|
def kind(self) -> str: ...
|
|
def kindOf(self, name: str) -> str: ...
|
|
def addInput(self, name: str) -> Value: ...
|
|
def replaceInput(self, i: _int, newValue: Value) -> Value: ...
|
|
def replaceInputWith(self, from_: Value, to: Value) -> None: ...
|
|
def replaceAllUsesWith(self, n: Node) -> None: ...
|
|
def insertBefore(self, n: Node) -> Node: ...
|
|
def insertAfter(self, n: Node) -> Node: ...
|
|
def isBefore(self, n: Node) -> _bool: ...
|
|
def isAfter(self, n: Node) -> _bool: ...
|
|
def moveBefore(self, n: Node) -> None: ...
|
|
def moveAfter(self, n: Node) -> None: ...
|
|
def removeInput(self, i: _int) -> None: ...
|
|
def removeAllInputs(self, i: _int) -> None: ...
|
|
def hasUses(self) -> _bool: ...
|
|
def eraseOutput(self, i: _int) -> None: ...
|
|
def addOutput(self) -> Value: ...
|
|
def scopeName(self) -> str: ...
|
|
def isNondeterministic(self) -> _bool: ...
|
|
def copyAttributes(self, rhs: Node) -> Node: ...
|
|
def copyMetadata(self, rhs: Node) -> Node: ...
|
|
def hasAttributes(self) -> _bool: ...
|
|
def hasAttribute(self, name: str) -> _bool: ...
|
|
def removeAttribute(self, attr: str) -> Node: ...
|
|
def namedInput(self, name: str) -> Value: ...
|
|
def sourceRange(self) -> SourceRange: ...
|
|
def owningBlock(self) -> Block: ...
|
|
def findNode(self, kind: str, recurse: _bool = True) -> Node: ...
|
|
def findAllNodes(self, kind: str, recurse: _bool = True) -> list[Node]: ...
|
|
def getModuleHierarchy(self) -> str: ...
|
|
def prev(self) -> Node: ...
|
|
def destroy(self) -> None: ...
|
|
def attributeNames(self) -> list[str]: ...
|
|
|
|
# Accessors for attributes as types.
|
|
def f(self, name: str) -> _float: ...
|
|
def f_(self, name: str, val: _float) -> Node: ...
|
|
def fs(self, name: str) -> list[_float]: ...
|
|
def fs_(self, name: str, val: list[_float]) -> Node: ...
|
|
def c(self, name: str) -> complex: ...
|
|
def c_(self, name: str, val: complex) -> Node: ...
|
|
def s(self, name: str) -> str: ...
|
|
def s_(self, name: str, val: str) -> Node: ...
|
|
def ss(self, name: str) -> list[str]: ...
|
|
def ss_(self, name: str, val: list[str]) -> Node: ...
|
|
def i(self, name: str) -> _int: ...
|
|
def i_(self, name: str, val: _int) -> Node: ...
|
|
# Cannot define "is" like this because it's a reserved keyword in python.
|
|
# def is(self, name: str) -> List[_int]: ...
|
|
# def is_(self, name: str, val: List[_int]) -> Node: ...
|
|
def g(self, name: str) -> Graph: ...
|
|
def g_(self, name: str, val: Graph) -> Node: ...
|
|
def gs(self, name: str) -> list[Graph]: ...
|
|
def gs_(self, name: str, val: list[Graph]) -> Node: ...
|
|
def ival(self, name: str) -> IValue: ...
|
|
def ival_(self, name: str, val: IValue) -> Node: ...
|
|
def t(self, name: str) -> Tensor: ...
|
|
def t_(self, name: str, val: Tensor) -> Node: ...
|
|
def ts(self, name: str) -> list[Tensor]: ...
|
|
def ts_(self, name: str, val: list[Tensor]) -> Node: ...
|
|
def ty(self, name: str) -> JitType: ...
|
|
def ty_(self, name: str, val: JitType) -> Node: ...
|
|
def tys(self, name: str) -> list[JitType]: ...
|
|
def tys_(self, name: str, val: list[JitType]) -> Node: ...
|
|
|
|
# Defined in torch/torch/csrc/jit/ir/ir.h
|
|
class Graph:
|
|
def inputs(self) -> Iterator[Value]: ...
|
|
def outputs(self) -> Iterator[Value]: ...
|
|
def nodes(self) -> Iterator[Node]: ...
|
|
def param_node(self) -> Node: ...
|
|
def return_node(self) -> Node: ...
|
|
def addInput(self, name: str = "") -> Value: ...
|
|
def eraseInput(self, i: _int) -> None: ...
|
|
def registerOutput(self, n: Value) -> _int: ...
|
|
def eraseOutput(self, i: _int) -> None: ...
|
|
def create(self, name: str, args, num_outputs: _int) -> Node: ...
|
|
def appendNode(self, n: Node) -> Node: ...
|
|
def prependNode(self, n: Node) -> Node: ...
|
|
def insertNode(self, n: Node) -> Node: ...
|
|
def block(self) -> Block: ...
|
|
def lint(self) -> None: ...
|
|
def alias_db(self) -> AliasDb: ...
|
|
def setInsertPoint(self, n: Block | Node) -> None: ...
|
|
def insert_point_guard(self, n: Block | Node) -> _InsertPoint: ...
|
|
def insertPoint(self) -> Node: ...
|
|
def insertGraph(self, callee: Graph, inputs: list[Value]) -> list[Value]: ...
|
|
def makeMultiOutputIntoTuple(self) -> None: ...
|
|
def copy(self) -> Graph: ...
|
|
|
|
# Defined in torch/aten/src/ATen/core/alias_info.h
|
|
class AliasInfo:
|
|
is_write: _bool
|
|
before_set: set[str]
|
|
after_set: set[str]
|
|
def __init__(
|
|
self,
|
|
is_write: _bool,
|
|
before_set: set[str],
|
|
after_set: set[str],
|
|
) -> None: ...
|
|
|
|
# Defined in torch/aten/src/ATen/core/function_schema.h
|
|
class Argument:
|
|
name: str
|
|
type: JitType
|
|
default_value: Any | None
|
|
def has_default_value(self) -> _bool: ...
|
|
kwarg_only: _bool
|
|
is_out: _bool
|
|
alias_info: AliasInfo | None
|
|
is_write: _bool
|
|
real_type: JitType
|
|
def __init__(
|
|
self,
|
|
name: str,
|
|
type: JitType,
|
|
N: _int | None,
|
|
defualt_value: Any | None,
|
|
kwarg_only: _bool,
|
|
alias_info: AliasInfo | None,
|
|
) -> None: ...
|
|
|
|
class FunctionSchema:
|
|
arguments: list[Argument]
|
|
returns: list[Argument]
|
|
name: str
|
|
overload_name: str
|
|
is_mutable: _bool
|
|
def __init__(
|
|
self,
|
|
name: str,
|
|
overload_name: str,
|
|
arguments: list[Argument],
|
|
returns: list[Argument],
|
|
is_vararg: _bool,
|
|
is_varret: _bool,
|
|
) -> None: ...
|
|
def _is_view_op(self) -> _bool: ...
|
|
|
|
class _UpgraderEntry:
|
|
bumped_at_version: _int
|
|
upgrader_name: str
|
|
old_schema: str
|
|
def __init__(
|
|
self,
|
|
bumped_at_version: _int,
|
|
upgrader_name: str,
|
|
old_schema: str,
|
|
) -> None: ...
|
|
|
|
class _UpgraderRange:
|
|
min_version: _int
|
|
max_version: _int
|
|
|
|
def _get_max_operator_version() -> _int: ...
|
|
def _get_operator_version_map() -> dict[str, list[_UpgraderEntry]]: ...
|
|
def _get_upgrader_ranges(name: str) -> list[_UpgraderRange]: ...
|
|
def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> None: ...
|
|
def _test_only_remove_entry_to_op_version(op_name: str) -> None: ...
|
|
|
|
# Defined in torch/csrc/jit/python/script_init.cpp
|
|
class ScriptModuleSerializer:
|
|
def __init__(self, export_writer: PyTorchFileWriter) -> None: ...
|
|
def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ...
|
|
def write_files(self) -> None: ...
|
|
def storage_context(self) -> SerializationStorageContext: ...
|
|
|
|
# Defined in torch/csrc/jit/python/script_init.cpp
|
|
class SerializationStorageContext:
|
|
def __init__(self) -> None: ...
|
|
def has_storage(self, storage: Storage) -> _bool: ...
|
|
def get_or_add_storage(self, storage: Storage) -> _int: ...
|
|
|
|
# Defined in torch/csrc/jit/python/script_init.cpp
|
|
class DeserializationStorageContext:
|
|
def __init__(self) -> None: ...
|
|
def get_storage(self, name: str, dtype: _dtype) -> Tensor: ...
|
|
def has_storage(self, name: str) -> _bool: ...
|
|
def add_storage(self, name: str, tensor: Tensor) -> _int: ...
|
|
|
|
# Defined in torch/csrc/jit/python/script_init.cpp
|
|
class ConcreteModuleTypeBuilder:
|
|
def __init__(self, obj: Any) -> None: ...
|
|
def set_module_dict(self): ...
|
|
def set_module_list(self): ...
|
|
def set_parameter_list(self): ...
|
|
def set_parameter_dict(self): ...
|
|
def add_attribute(
|
|
self,
|
|
name: str,
|
|
ty: JitType,
|
|
is_param: _bool,
|
|
is_buffer: _bool,
|
|
): ...
|
|
def add_module(self, name: str, meta: ConcreteModuleType): ...
|
|
def add_constant(self, name: str, value: Any): ...
|
|
def add_overload(self, method_name: str, overloaded_method_names: list[str]): ...
|
|
def add_builtin_function(self, name: str, symbol_name: str): ...
|
|
def add_failed_attribute(self, name: str, failure_reason: str): ...
|
|
def add_function_attribute(
|
|
self,
|
|
name: str,
|
|
ty: JitType,
|
|
func: Callable[..., Any],
|
|
): ...
|
|
def add_ignored_attribute(self, name: str): ...
|
|
def add_ignored_attributes(self, names: list[str]): ...
|
|
def add_forward_hook(self, hook: Callable[..., Any]): ...
|
|
def add_forward_pre_hook(self, pre_hook: Callable[..., Any]): ...
|
|
|
|
class ConcreteModuleType:
|
|
def get_constants(self) -> dict[str, Any]: ...
|
|
def equals(self, other: ConcreteModuleType) -> _bool: ...
|
|
@staticmethod
|
|
def from_jit_type(ty: JitType) -> ConcreteModuleType: ...
|
|
|
|
class CallStack:
|
|
def __init__(self, name: str, range: SourceRange) -> None: ...
|
|
|
|
class ErrorReport:
|
|
def __init__(self, range: SourceRange) -> None: ...
|
|
def what(self) -> str: ...
|
|
@staticmethod
|
|
def call_stack() -> str: ...
|
|
|
|
class CompilationUnit:
|
|
def __init__(self, lang: str = ..., _frames_up: _int = ...) -> None: ...
|
|
def find_function(self, name: str) -> ScriptFunction: ...
|
|
def __getattr__(self, name: str) -> ScriptFunction: ...
|
|
def define(
|
|
self,
|
|
script: str,
|
|
rcb: ResolutionCallback = ...,
|
|
_frames_up: _int = ...,
|
|
): ...
|
|
def get_interface(self, name: str) -> InterfaceType: ...
|
|
def get_functions(self) -> list[ScriptFunction]: ...
|
|
def create_function(
|
|
self,
|
|
name: str,
|
|
graph: Graph,
|
|
shouldMangle: _bool = ...,
|
|
) -> ScriptFunction: ...
|
|
def get_class(self, name: str) -> ClassType: ...
|
|
|
|
class ScriptObject:
|
|
def setattr(self, name: str, value: Any): ...
|
|
def _get_method(self, name: str) -> ScriptMethod: ...
|
|
def _type(self) -> ClassType: ...
|
|
|
|
class ScriptModule(ScriptObject):
|
|
def _method_names(self) -> list[str]: ...
|
|
def _get_method(self, name: str) -> ScriptMethod: ...
|
|
|
|
class LiteScriptModule:
|
|
def __call__(self, *input): ...
|
|
def find_method(self, method_name: str): ...
|
|
def forward(self, *input) -> list[str]: ...
|
|
def run_method(self, method_name: str, *input): ...
|
|
|
|
# NOTE: switch to collections.abc.Callable in python 3.9
|
|
class ScriptFunction(Generic[P, R]):
|
|
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
|
|
def save(self, filename: str, _extra_files: dict[str, bytes]) -> None: ...
|
|
def save_to_buffer(self, _extra_files: dict[str, bytes]) -> bytes: ...
|
|
@property
|
|
def graph(self) -> Graph: ...
|
|
def inlined_graph(self) -> Graph: ...
|
|
def schema(self) -> FunctionSchema: ...
|
|
def code(self) -> str: ...
|
|
def name(self) -> str: ...
|
|
@property
|
|
def qualified_name(self) -> str: ...
|
|
|
|
# NOTE: switch to collections.abc.Callable in python 3.9
|
|
class ScriptMethod(Generic[P, R]):
|
|
graph: Graph
|
|
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
|
|
@property
|
|
def owner(self) -> ScriptModule: ...
|
|
@property
|
|
def name(self) -> str: ...
|
|
@property
|
|
def schema(self) -> FunctionSchema: ...
|
|
|
|
class ScriptDict(Generic[K, T]):
|
|
def __init__(self, dict: dict[K, T]) -> None: ...
|
|
def __len__(self) -> _int: ...
|
|
def __contains__(self, key: K) -> _bool: ...
|
|
def __getitem__(self, key: K) -> T: ...
|
|
def __setitem__(self, key: K, value: T) -> None: ...
|
|
def __delitem__(self, key: K) -> None: ...
|
|
def __iter__(self) -> Iterator[K]: ...
|
|
def items(self) -> Iterator[tuple[K, T]]: ...
|
|
def keys(self) -> Iterator[K]: ...
|
|
|
|
class ScriptList(Generic[T]):
|
|
def __init__(self, list: list[T]) -> None: ...
|
|
def __len__(self) -> _int: ...
|
|
def __contains__(self, item: T) -> _bool: ...
|
|
@overload
|
|
def __getitem__(self, idx: _int) -> T: ...
|
|
@overload
|
|
def __getitem__(self, idx: slice) -> ScriptList[T]: ...
|
|
@overload
|
|
def __setitem__(self, idx: _int, value: T) -> None: ...
|
|
@overload
|
|
def __setitem__(self, idx: slice, value: list[T]) -> None: ...
|
|
def __delitem__(self, idx: _int) -> None: ...
|
|
def __iter__(self) -> Iterator[T]: ...
|
|
def count(self, value: T) -> _int: ...
|
|
def remove(self, value: T) -> None: ...
|
|
def append(self, value: T) -> None: ...
|
|
def clear(self) -> None: ...
|
|
@overload
|
|
def extend(self, values: list[T]) -> None: ...
|
|
@overload
|
|
def extend(self, values: Iterable[T]) -> None: ...
|
|
@overload
|
|
def pop(self) -> T: ...
|
|
@overload
|
|
def pop(self, idx: _int) -> T: ...
|
|
|
|
class ModuleDict:
|
|
def __init__(self, mod: ScriptModule) -> None: ...
|
|
def items(self) -> list[tuple[str, Any]]: ...
|
|
|
|
class ParameterDict:
|
|
def __init__(self, mod: ScriptModule) -> None: ...
|
|
|
|
class BufferDict:
|
|
def __init__(self, mod: ScriptModule) -> None: ...
|
|
|
|
# Defined in torch/csrc/jit/api/module.h
|
|
class Module: ...
|
|
|
|
# Defined in torch/csrc/Module.cpp
|
|
def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension
|
|
def _autograd_init() -> _bool: ... # THPAutograd_initExtension
|
|
def _add_docstr(obj: T, doc_obj: str) -> T: ... # THPModule_addDocStr
|
|
def _init_names(arg: Sequence[type]) -> None: ... # THPModule_initNames
|
|
def _has_distributed() -> _bool: ... # THPModule_hasDistributed
|
|
def _set_default_tensor_type(type) -> None: ... # THPModule_setDefaultTensorType
|
|
def _set_default_dtype(d: _dtype) -> None: ... # THPModule_setDefaultDtype
|
|
def _infer_size(arg1: Size, arg2: Size) -> Size: ... # THPModule_inferSize
|
|
def _crash_if_csrc_asan() -> _int: ... # THPModule_crashIfCsrcASAN
|
|
def _crash_if_csrc_ubsan() -> _int: ... # THPModule_crashIfCsrcUBSAN
|
|
def _crash_if_aten_asan() -> _int: ... # THPModule_crashIfATenASAN
|
|
def _show_config() -> str: ... # THPModule_showConfig
|
|
def _cxx_flags() -> str: ... # THPModule_cxxFlags
|
|
def _parallel_info() -> str: ... # THPModule_parallelInfo
|
|
def _get_cpu_capability() -> str: ... # THPModule_getCpuCapability
|
|
def _set_backcompat_broadcast_warn(
|
|
arg: _bool,
|
|
) -> None: ... # THPModule_setBackcompatBroadcastWarn
|
|
def _get_backcompat_broadcast_warn() -> (
|
|
_bool
|
|
): ... # THPModule_getBackcompatBroadcastWarn
|
|
def _set_backcompat_keepdim_warn(
|
|
arg: _bool,
|
|
) -> None: ... # THPModule_setBackcompatKeepdimWarn
|
|
def _get_backcompat_keepdim_warn() -> _bool: ... # THPModule_getBackcompatKeepdimWarn
|
|
def get_num_thread() -> _int: ... # THPModule_getNumThreads
|
|
def set_num_threads(nthreads: _int) -> None: ... # THPModule_setNumThreads
|
|
def get_num_interop_threads() -> _int: ... # THPModule_getNumInteropThreads
|
|
def set_num_interop_threads(
|
|
nthreads: _int,
|
|
) -> None: ... # THPModule_setNumInteropThreads
|
|
def _get_cudnn_enabled() -> _bool: ... # THPModule_userEnabledCuDNN
|
|
def _set_cudnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledCuDNN
|
|
def _get_flash_sdp_enabled() -> _bool: ... # THPModule_userEnabledFusedSDP
|
|
def _set_sdp_use_flash(arg: _bool) -> None: ... # THPModule_setSDPUseFlash
|
|
def _get_mem_efficient_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
|
|
def _set_sdp_use_mem_efficient(
|
|
arg: _bool,
|
|
) -> None: ... # THPModule_setSDPUseMemEfficient
|
|
def _get_math_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
|
|
def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath
|
|
def _get_math_sdp_allow_fp16_bf16_reduction() -> (
|
|
_bool
|
|
): ... # THPModule_allowFP16BF16ReductionMathSDP
|
|
def _set_math_sdp_allow_fp16_bf16_reduction(
|
|
arg: _bool,
|
|
) -> None: ... # THPModule_setAllowFP16BF16ReductionMathSDP
|
|
def _get_overrideable_sdp_enabled() -> (
|
|
_bool
|
|
): ... # THPModule_userEnabledOverrideableSDP
|
|
def _set_sdp_use_overrideable(
|
|
arg: _bool,
|
|
) -> None: ... # THPModule_setSDPUseOverrideable
|
|
def _get_sdp_priority_order() -> list[_int]: ... # THPModule_getSDPPriorityOrder
|
|
def _set_sdp_priority_order(
|
|
arg: list[_int],
|
|
) -> None: ... # THPModule_setSDPPriorityOrder
|
|
def _get_cudnn_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
|
|
def _set_sdp_use_cudnn(arg: _bool) -> None: ... # THPModule_setSDPUseMath
|
|
def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn
|
|
def _set_mkldnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledMkldnn
|
|
def _get_cudnn_benchmark() -> _bool: ... # THPModule_benchmarkCuDNN
|
|
def _set_cudnn_benchmark(arg: _bool) -> None: ... # THPModule_setBenchmarkCuDNN
|
|
def _get_miopen_immediate() -> _bool: ... # THPModule_userImmediateMiopen
|
|
def _set_miopen_immediate(arg: _bool) -> None: ... # THPModule_setUserImmediateMiopen
|
|
def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN
|
|
def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN
|
|
def _get_mkldnn_deterministic() -> _bool: ... # THPModule_deterministicMkldnn
|
|
def _set_mkldnn_deterministic(
|
|
arg: _bool,
|
|
) -> None: ... # THPModule_setDeterministicMkldnn
|
|
def _get_onednn_allow_tf32() -> _bool: ... # THPModule_allowTF32OneDNN
|
|
def _set_onednn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32OneDNN
|
|
def _get_deterministic_algorithms() -> _bool: ... # THPModule_deterministicAlgorithms
|
|
def _get_deterministic_algorithms_warn_only() -> (
|
|
_bool
|
|
): ... # THPModule_deterministicAlgorithmsWarnOnly
|
|
def _set_deterministic_algorithms(
|
|
mode: _bool,
|
|
*,
|
|
warn_only: _bool = ...,
|
|
) -> None: ... # THPModule_setDeterministicAlgorithms
|
|
def _get_deterministic_fill_uninitialized_memory() -> (
|
|
_bool
|
|
): ... # THPModule_deterministicFillUninitializedMemory
|
|
def _set_deterministic_fill_uninitialized_memory(
|
|
arg: _bool,
|
|
) -> None: ... # THPModule_setDeterministicFillUninitializedMemory
|
|
def _get_nnpack_enabled() -> _bool: ... # THPModule_userEnabledNNPACK
|
|
def _set_nnpack_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledNNPACK
|
|
def _get_warnAlways() -> _bool: ... # THPModule_warnAlways
|
|
def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways
|
|
def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN
|
|
def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN
|
|
def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS
|
|
def _set_cublas_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuBLAS
|
|
def _get_float32_matmul_precision() -> str: ... # THPModule_float32MatmulPrecision
|
|
def _set_float32_matmul_precision(
|
|
arg: str,
|
|
) -> None: ... # THPModule_setFloat32MatmulPrecision
|
|
def _get_cublas_allow_fp16_reduced_precision_reduction() -> tuple[
|
|
_bool, _bool
|
|
]: ... # THPModule_allowFP16ReductionCuBLAS
|
|
def _set_cublas_allow_fp16_reduced_precision_reduction(
|
|
arg: _bool,
|
|
allow_splitk: _bool = ...,
|
|
) -> None: ... # THPModule_setAllowFP16ReductionCuBLAS
|
|
def _get_cublas_allow_bf16_reduced_precision_reduction() -> tuple[
|
|
_bool, _bool
|
|
]: ... # THPModule_allowBF16ReductionCuBLAS
|
|
def _set_cublas_allow_bf16_reduced_precision_reduction(
|
|
arg: _bool,
|
|
allow_splitk: _bool = ...,
|
|
) -> None: ... # THPModule_setAllowBF16ReductionCuBLAS
|
|
def _get_cublas_allow_fp16_accumulation() -> (
|
|
_bool
|
|
): ... # THPModule_allowFP16AccumulationCuBLAS
|
|
def _set_cublas_allow_fp16_accumulation(
|
|
arg: _bool,
|
|
) -> None: ... # THPModule_setAllowFP16AccumulationCuBLAS
|
|
def _get_sm_carveout_experimental() -> _int | None: ...
|
|
def _set_sm_carveout_experimental(arg: _int | None) -> None: ...
|
|
def _set_conj(x: Tensor, conj: _bool) -> None: ...
|
|
def _set_neg(x: Tensor, neg: _bool) -> None: ...
|
|
def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...
|
|
def _autocast_supported_devices() -> list[str]: ...
|
|
def _meta_in_tls_dispatch_include() -> _bool: ...
|
|
def _stash_obj_in_tls(key: str, arg: Any) -> None: ...
|
|
def _get_obj_in_tls(key: str) -> Any: ...
|
|
def _is_key_in_tls(key: str) -> _bool: ...
|
|
def _select_batch_norm_backend(*args, **kwargs) -> BatchNormBackend: ...
|
|
def _select_conv_backend(*args, **kwargs) -> ConvBackend: ...
|
|
def _conv_determine_backend_memory_format(
|
|
input: Tensor,
|
|
weight: Tensor,
|
|
backend: ConvBackend,
|
|
) -> memory_format: ...
|
|
def _has_storage(x: Tensor) -> _bool: ...
|
|
def _construct_storage_from_data_pointer(
|
|
data_ptr: _int,
|
|
device: torch.device,
|
|
size: _int,
|
|
) -> Storage: ...
|
|
def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
|
|
def _group_tensors_by_device_and_dtype(
|
|
nested_tensorlists: list[list[Tensor | None]],
|
|
with_indices: _bool = False,
|
|
) -> dict[
|
|
tuple[torch.device, torch.dtype],
|
|
tuple[list[list[Tensor | None]], list[_int]],
|
|
]: ...
|
|
def _initCrashHandler() -> None: ...
|
|
|
|
# NB: There is no Capsule type in typing, see
|
|
# https://github.com/python/cpython/issues/109562
|
|
def _to_dlpack(
|
|
data: Tensor,
|
|
dl_device: tuple[IntEnum, _int] | None = None,
|
|
copy: _bool | None = None,
|
|
) -> Any: ... # THPModule_toDLPack
|
|
def _to_dlpack_versioned(
|
|
data: Tensor,
|
|
dl_device: tuple[IntEnum, _int] | None = None,
|
|
copy: _bool | None = None,
|
|
) -> Any: ... # THPModule_toDLPackVersioned
|
|
def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack
|
|
def _torchDeviceToDLDevice(
|
|
device: torch.device,
|
|
) -> tuple[_int, _int]: ... # THPModule_torchDeviceToDLDevice
|
|
def _get_cpp_backtrace(
|
|
frames_to_skip: _int,
|
|
maximum_number_of_frames: _int,
|
|
) -> str: ... # THPModule_getCppBacktrace
|
|
def set_flush_denormal(arg: _bool) -> _bool: ... # THPModule_setFlushDenormal
|
|
def get_default_dtype() -> _dtype: ... # THPModule_getDefaultDtype
|
|
def _get_default_device() -> str: ... # THPModule_getDefaultDevice
|
|
def _get_qengine() -> _int: ... # THPModule_qEngine
|
|
def _set_qengine(qengine: _int) -> None: ... # THPModule_setQEngine
|
|
def _supported_qengines() -> list[_int]: ... # THPModule_supportedQEngines
|
|
def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK
|
|
def _check_sparse_tensor_invariants() -> (
|
|
_bool
|
|
): ... # THPModule_checkSparseTensorInvariants
|
|
def _set_check_sparse_tensor_invariants(
|
|
arg: _bool,
|
|
) -> None: ... # THPModule_setCheckSparseTensorInvariants
|
|
def _is_default_mobile_cpu_allocator_set() -> (
|
|
_bool
|
|
): ... # THPModule_isDefaultMobileCPUAllocatorSet
|
|
def _set_default_mobile_cpu_allocator() -> (
|
|
None
|
|
): ... # THPModule_setDefaultMobileCPUAllocator
|
|
def _unset_default_mobile_cpu_allocator() -> (
|
|
None
|
|
): ... # THPModule_unsetDefaultMobileCPUAllocator
|
|
def _is_torch_function_enabled() -> _bool: ... # THPModule_isEnabledTorchFunction
|
|
def _is_torch_function_all_disabled() -> (
|
|
_bool
|
|
): ... # THPModule_isAllDisabledTorchFunction
|
|
def _has_torch_function(
|
|
args: Iterable[Any],
|
|
) -> _bool: ... # THPModule_has_torch_function
|
|
def _has_torch_function_unary(Any) -> _bool: ... # THPModule_has_torch_function_unary
|
|
def _has_torch_function_variadic(
|
|
*args: Any,
|
|
) -> _bool: ... # THPModule_has_torch_function_variadic
|
|
def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting
|
|
def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting
|
|
def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython
|
|
def _log_api_usage_metadata(
|
|
event: str,
|
|
metadata_map: dict[str, str],
|
|
) -> None: ... # LogAPIUsageMetadataFromPython
|
|
def _demangle(str) -> str: ... # c10::demangle
|
|
def _disabled_torch_function_impl(
|
|
func: Callable,
|
|
types: Iterable[type],
|
|
args: tuple,
|
|
kwargs: dict,
|
|
) -> Any: ... # THPModule_disable_torch_function
|
|
def _disabled_torch_dispatch_impl(
|
|
func: Callable,
|
|
types: Iterable[type],
|
|
args: tuple,
|
|
kwargs: dict,
|
|
) -> Any: ... # THPModule_disable_dispatch_function
|
|
def _get_linalg_preferred_backend() -> _LinalgBackend: ...
|
|
def _set_linalg_preferred_backend(arg: _LinalgBackend): ...
|
|
def _get_fp32_precision_getter(backend: str, op: str) -> str: ...
|
|
def _set_fp32_precision_setter(backend: str, op: str, value: str) -> str: ...
|
|
def _ensureCUDADeviceGuardSet() -> None: ...
|
|
|
|
class _LinalgBackend:
|
|
Default: _LinalgBackend
|
|
Cusolver: _LinalgBackend
|
|
Magma: _LinalgBackend
|
|
|
|
# mypy error:
|
|
# Detected enum "torch._C.BatchNormBackend" in a type stub with zero
|
|
# members. There is a chance this is due to a recent change in the semantics
|
|
# of enum membership. If so, use `member = value` to mark an enum member,
|
|
# instead of `member: type`
|
|
class BatchNormBackend(Enum): ... # type: ignore[misc]
|
|
|
|
def _get_blas_preferred_backend() -> _BlasBackend: ...
|
|
def _set_blas_preferred_backend(arg: _BlasBackend): ...
|
|
|
|
class _BlasBackend:
|
|
Default: _BlasBackend
|
|
Cublas: _BlasBackend
|
|
Cublaslt: _BlasBackend
|
|
Ck: _BlasBackend
|
|
|
|
def _get_rocm_fa_preferred_backend() -> torch._C._ROCmFABackend: ...
|
|
def _set_rocm_fa_preferred_backend(arg: torch._C._ROCmFABackend): ...
|
|
|
|
class _ROCmFABackend:
|
|
Default: _ROCmFABackend
|
|
AOTriton: _ROCmFABackend
|
|
Ck: _ROCmFABackend
|
|
|
|
# mypy error:
|
|
# Error (MYPY) [misc]
|
|
# Detected enum "torch._C.ConvBackend" in a type stub with zero members.
|
|
# There is a chance this is due to a recent change in the semantics of enum
|
|
# membership. If so, use `member = value` to mark an enum member, instead of
|
|
# `member: type`
|
|
class ConvBackend(Enum): ... # type: ignore[misc]
|
|
|
|
class Tag(Enum):
|
|
${tag_attributes}
|
|
|
|
# Defined in `valgrind.h` and `callgrind.h` respectively.
|
|
def _valgrind_supported_platform() -> _bool: ... # NVALGRIND
|
|
def _valgrind_toggle() -> None: ... # CALLGRIND_TOGGLE_COLLECT
|
|
def _valgrind_toggle_and_dump_stats() -> (
|
|
None
|
|
): ... # CALLGRIND_TOGGLE_COLLECT and CALLGRIND_DUMP_STATS
|
|
|
|
has_openmp: _bool
|
|
has_mkl: _bool
|
|
_has_kleidiai: _bool
|
|
_has_mps: _bool
|
|
has_lapack: _bool
|
|
_has_cuda: _bool
|
|
_has_magma: _bool
|
|
_has_xpu: _bool
|
|
_has_mkldnn: _bool
|
|
_has_mkldnn_acl: _bool
|
|
_has_cudnn: _bool
|
|
_has_cusparselt: _bool
|
|
has_spectral: _bool
|
|
_GLIBCXX_USE_CXX11_ABI: _bool
|
|
default_generator: Generator
|
|
|
|
# Defined in torch/csrc/autograd/init.cpp
|
|
def _set_grad_enabled(enabled: _bool) -> None: ...
|
|
def is_grad_enabled() -> _bool: ...
|
|
def _set_fwd_grad_enabled(enabled: _bool) -> None: ...
|
|
def _is_fwd_grad_enabled() -> _bool: ...
|
|
def _any_requires_grad(*args, **kwargs) -> _bool: ...
|
|
def _any_output_is_alias_to_input_or_output(*args, **kwargs) -> _bool: ...
|
|
def is_inference_mode_enabled() -> _bool: ...
|
|
@overload
|
|
def set_autocast_enabled(device_type: str, enabled: _bool) -> None: ...
|
|
@overload
|
|
def set_autocast_enabled(enabled: _bool) -> None: ...
|
|
@overload
|
|
def is_autocast_enabled(device_type: str) -> _bool: ...
|
|
@overload
|
|
def is_autocast_enabled() -> _bool: ...
|
|
def set_autocast_dtype(device_type: str, dtype: _dtype) -> None: ...
|
|
def get_autocast_dtype(device_type: str) -> _dtype: ...
|
|
def clear_autocast_cache() -> None: ...
|
|
def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
|
|
def is_autocast_cpu_enabled() -> _bool: ...
|
|
def _is_any_autocast_enabled() -> _bool: ...
|
|
def _is_autocast_available(device_type: str) -> _bool: ...
|
|
def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
|
|
def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
|
|
def get_autocast_cpu_dtype() -> _dtype: ...
|
|
def get_autocast_gpu_dtype() -> _dtype: ...
|
|
def autocast_increment_nesting() -> _int: ...
|
|
def autocast_decrement_nesting() -> _int: ...
|
|
def is_autocast_cache_enabled() -> _bool: ...
|
|
def set_autocast_cache_enabled(enabled: _bool) -> None: ...

def _increment_version(tensors: Iterable[Tensor]) -> None: ...
def set_anomaly_enabled(enabled: _bool, check_nan: _bool = True) -> None: ...
def is_anomaly_enabled() -> _bool: ...
def is_anomaly_check_nan_enabled() -> _bool: ...
def _is_multithreading_enabled() -> _bool: ...
def _set_multithreading_enabled(enabled: _bool) -> None: ...
def _set_view_replay_enabled(enabled: _bool) -> None: ...
def _is_view_replay_enabled() -> _bool: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ...
def __set_forward_AD_enabled(enabled: _bool) -> None: ...
def __is_forward_AD_enabled() -> _bool: ...
def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _reset_default_hooks() -> None: ...
def _is_torch_function_mode_enabled() -> _bool: ...
def _push_on_torch_function_stack(cls: Any) -> None: ...
def _pop_torch_function_stack() -> Any: ...
def _get_function_stack_at(idx: _int) -> Any: ...
def _len_torch_function_stack() -> _int: ...
def _set_torch_dispatch_mode(cls: Any) -> None: ...
def _push_on_torch_dispatch_stack(cls: TorchDispatchMode) -> None: ...
def _pop_torch_dispatch_stack(mode_key: _TorchDispatchModeKey | None = None) -> Any: ...
def _get_dispatch_mode(mode_key: _TorchDispatchModeKey | None) -> Any: ...
def _unset_dispatch_mode(mode: _TorchDispatchModeKey) -> TorchDispatchMode | None: ...
def _set_dispatch_mode(mode: TorchDispatchMode) -> None: ...
def _get_dispatch_stack_at(idx: _int) -> Any: ...
def _len_torch_dispatch_stack() -> _int: ...
def _activate_gpu_trace() -> None: ...
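
# Illustrative sketch of the forward-AD dual-level bindings; the supported
# wrapper is torch.autograd.forward_ad.dual_level. `primal` and `tangent`
# are hypothetical same-shaped Tensors.
#
#     level = torch._C._enter_dual_level()
#     dual = torch._C._make_dual(primal, tangent, level)
#     torch._C._exit_dual_level(level)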

class _DisableTorchDispatch:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _EnableTorchFunction:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _EnablePythonDispatcher:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _DisablePythonDispatcher:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _EnablePreDispatch:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _DisableFuncTorch:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _DisableAutocast:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _InferenceMode:
    def __init__(self, enabled: _bool) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

def _set_autograd_fallback_mode(mode: str) -> None: ...
def _get_autograd_fallback_mode() -> str: ...
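
# All of the classes above are RAII-style guards exposed as context
# managers; a minimal sketch:
#
#     with torch._C._InferenceMode(True):
#         ...  # ops in this block run with inference mode enabled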

# Defined in torch/csrc/jit/python/script_init.cpp
class LoggerBase: ...
class NoopLogger(LoggerBase): ...
class LockingLogger(LoggerBase): ...

class AggregationType(Enum):
    SUM = 0
    AVG = 1

class FileCheck:
    def run(self, test_string: str) -> None: ...
    def check(self, test_string: str) -> FileCheck: ...
    def check_not(self, test_string: str) -> FileCheck: ...
    def check_same(self, test_string: str) -> FileCheck: ...
    def check_next(self, test_string: str) -> FileCheck: ...
    def check_count(
        self,
        test_string: str,
        count: _int,
        exactly: _bool = False,
    ) -> FileCheck: ...
    def check_dag(self, test_string: str) -> FileCheck: ...
    def check_source_highlighted(self, test_string: str) -> FileCheck: ...
    def check_regex(self, test_string: str) -> FileCheck: ...
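
# FileCheck assertions chain fluently and `run` performs the match; a
# sketch, where `ir_text` stands in for any string such as a printed Graph:
#
#     torch._C.FileCheck().check("aten::add").check_not("aten::mul").run(ir_text)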

# Defined in torch/csrc/jit/python/init.cpp
class PyTorchFileReader:
    @overload
    def __init__(self, name: str) -> None: ...
    @overload
    def __init__(self, buffer: IO[bytes]) -> None: ...
    def get_record(self, name: str) -> bytes: ...
    def get_all_records(self) -> list[str]: ...
    def serialization_id(self) -> str: ...

class PyTorchFileWriter:
    @overload
    def __init__(
        self,
        name: str,
        compute_crc32: _bool = True,
        storage_alignment: _int = 64,
    ) -> None: ...
    @overload
    def __init__(
        self,
        buffer: IO[bytes],
        compute_crc32: _bool = True,
        storage_alignment: _int = 64,
    ) -> None: ...
    def write_record(
        self,
        name: str,
        data: Storage | bytes | _int,
        size: _int,
    ) -> None: ...
    def write_end_of_file(self) -> None: ...
    def set_min_version(self, version: _int) -> None: ...
    def get_all_written_records(self) -> list[str]: ...
    def archive_name(self) -> str: ...
    def serialization_id(self) -> str: ...
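
# Round-tripping a record through the zip container (sketch; "archive.pt"
# is a hypothetical path):
#
#     writer = torch._C.PyTorchFileWriter("archive.pt")
#     payload = b"hello"
#     writer.write_record("extra/note", payload, len(payload))
#     writer.write_end_of_file()
#     assert torch._C.PyTorchFileReader("archive.pt").get_record("extra/note") == payload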

def _jit_get_inline_everything_mode() -> _bool: ...
def _jit_set_inline_everything_mode(enabled: _bool) -> None: ...
def _jit_get_logging_option() -> str: ...
def _jit_set_logging_option(option: str) -> None: ...
def _jit_set_logging_stream(stream_name: str) -> None: ...
def _jit_pass_cse(Graph) -> _bool: ...
def _jit_pass_dce(Graph) -> None: ...
def _jit_pass_dce_graph(Graph) -> None: ...
def _jit_pass_lint(Graph) -> None: ...
def _make_opaque_object(payload: Any) -> ScriptObject: ...
def _get_opaque_object_payload(obj: ScriptObject) -> Any: ...
def _set_opaque_object_payload(obj: ScriptObject, payload: Any) -> None: ...
def _register_opaque_type(type_name: str) -> None: ...
def _is_opaque_type_registered(type_name: str) -> _bool: ...

# Defined in torch/csrc/jit/python/python_custom_class.cpp
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...

# Defined in torch/csrc/Module.cpp
def _rename_privateuse1_backend(backend: str) -> None: ...
def _get_privateuse1_backend_name() -> str: ...

# Defined in torch/csrc/Generator.cpp
class Generator:
    device: _device
    def __init__(self, device: DeviceLikeType | None = None) -> None: ...
    def __reduce__(
        self,
    ) -> tuple[type[Generator], tuple[_device], tuple[_int, _int | None, Tensor]]: ...
    def __setstate__(self, state: tuple[_int, _int | None, Tensor]) -> None: ...
    def get_state(self) -> Tensor: ...
    def set_state(self, _new_state: Tensor) -> Generator: ...
    def clone_state(self) -> Generator: ...
    def graphsafe_get_state(self) -> Generator: ...
    def graphsafe_set_state(self, _new_state: Generator) -> Generator: ...
    def set_offset(self, offset: _int) -> Generator: ...
    def get_offset(self) -> _int: ...
    def manual_seed(self, seed: _int) -> Generator: ...
    def seed(self) -> _int: ...
    def initial_seed(self) -> _int: ...
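
# Seeding and snapshotting RNG state with a Generator (sketch):
#
#     gen = torch.Generator(device="cpu").manual_seed(0)
#     state = gen.get_state()  # Tensor snapshot of the RNG state
#     gen.set_state(state)     # restoring it makes subsequent draws repeat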

# Defined in torch/csrc/utils/python_dispatch.cpp

class _DispatchOperatorHandle:
    def schema(self) -> FunctionSchema: ...
    def debug(self) -> str: ...
    def redispatch_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Any: ...

class _DispatchModule:
    def reset(self) -> None: ...
    def def_(self, schema: str, alias: str = "") -> _DispatchModule: ...
    def def_legacy(self, schema: str) -> _DispatchModule: ...
    def def_name_t_t(
        self,
        name: str,
        dispatch: str,
        debug: str = "default_def_name_t_t",
    ) -> _DispatchModule: ...
    def def_schema_t_t(
        self,
        schema: str,
        dispatch: str,
        alias: str,
        debug: str = "default_def_schema_t_t",
    ) -> _DispatchModule: ...
    def impl_t_t(
        self,
        name: str,
        dispatch: str,
        debug: str = "impl_t_t",
    ) -> _DispatchModule: ...
    def impl_with_aoti_compile(
        self,
        ns: str,
        op_name_with_overload: str,
        dispatch: _dispatchkey,
    ) -> None: ...
    def impl(self, name: str, dispatch: _dispatchkey, func: Callable) -> None: ...
    def define(self, schema: str, alias: str = "") -> str: ...
    def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ...
    def fallback(
        self,
        dispatch: _dispatchkey,
        func: Callable,
        with_keyset: _bool = False,
    ) -> None: ...
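
# Registering an operator through the raw dispatcher bindings (sketch;
# torch.library.Library is the supported wrapper around this machinery, and
# "mylib"/"myop" are hypothetical names):
#
#     m = torch._C._dispatch_library("FRAGMENT", "mylib", "")
#     m.define("myop(Tensor x) -> Tensor")
#     m.impl("myop", "CPU", lambda x: x + 1)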

_after_ADInplaceOrView_keyset: DispatchKeySet
_after_autograd_keyset: DispatchKeySet

class _SafeKernelFunction:
    def call_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Any: ...
    @property
    def op_handle(self) -> _DispatchOperatorHandle: ...

def _dispatch_library(
    kind: str,
    name: str,
    dispatch: str,
    file: str = "",
    linenum: Any = 0,
) -> _DispatchModule: ...
def _dispatch_dump(name: str) -> str: ...
def _dispatch_dump_table(name: str) -> str: ...
def _dispatch_check_invariants(name: str) -> None: ...
def _dispatch_check_all_invariants() -> None: ...
def _dispatch_call_boxed(handle: _DispatchOperatorHandle, *args, **kwargs) -> Any: ...
def _dispatch_find_schema_or_throw(
    name: str,
    overload_name: str,
) -> _DispatchOperatorHandle: ...
def _dispatch_set_report_error_callback(
    handle: _DispatchOperatorHandle,
    callback: Callable,
) -> None: ...
def _dispatch_has_kernel(name: str) -> _bool: ...
def _dispatch_has_kernel_for_dispatch_key(
    name: str,
    dispatch: _dispatchkey,
) -> _bool: ...
def _dispatch_has_kernel_for_any_dispatch_key(
    name: str,
    dispatch_key_set: DispatchKeySet,
) -> _bool: ...
def _dispatch_kernel_for_dispatch_key_is_fallthrough(
    name: str,
    dispatch: _dispatchkey,
) -> _bool: ...
def _dispatch_has_computed_kernel_for_dispatch_key(
    name: str,
    dispatch: _dispatchkey,
) -> _bool: ...
def _dispatch_get_computed_kernel_for_dispatch_key(
    name: str,
    dispatch: _dispatchkey,
) -> _SafeKernelFunction: ...
def _dispatch_find_dangling_impls() -> list[str]: ...
def _dispatch_get_all_op_names() -> list[str]: ...
def _dispatch_tls_set_dispatch_key_excluded(
    dispatch: _dispatchkey,
    val: _bool,
) -> None: ...
def _dispatch_tls_is_dispatch_key_excluded(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_tls_set_dispatch_key_included(
    dispatch: _dispatchkey,
    val: _bool,
) -> None: ...
def _dispatch_tls_is_dispatch_key_included(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ...
def _dispatch_key_name(dispatch: _dispatchkey) -> str: ...
def _dispatch_key_for_device(device_type: str) -> str: ...
def _parse_dispatch_key(key: str) -> DispatchKey | None: ...
def _dispatch_key_parse(dispatch: _dispatchkey) -> DispatchKey: ...
def _dispatch_num_backends() -> _int: ...
def _dispatch_pystub(name: str, overload: str) -> tuple[str, str] | None: ...
def _dispatch_is_alias_key(dispatch: _dispatchkey) -> _bool: ...
def _functionality_to_backend_keys(dispatch: _dispatchkey) -> list[DispatchKey]: ...
def _functionalization_reapply_views_tls() -> _bool: ...
def _only_lift_cpu_tensors() -> _bool: ...
def _set_only_lift_cpu_tensors(value: _bool) -> None: ...
def _set_throw_on_mutable_data_ptr(tensor: Tensor) -> None: ...
def _set_warn_deprecated_on_mutable_data_ptr(tensor: Tensor) -> None: ...

class DispatchKey(Enum):
    ${dispatch_key_hints}

class DispatchKeySet:
    def __init__(self, key: DispatchKey) -> None: ...
    def __or__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __sub__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def raw_repr(self) -> _int: ...
    @staticmethod
    def from_raw_repr(raw: _int) -> DispatchKeySet: ...
    def highestPriorityTypeId(self) -> DispatchKey: ...
    def has(self, k: _dispatchkey) -> _bool: ...
    def add(self, k: _dispatchkey) -> DispatchKeySet: ...
    def remove(self, k: _dispatchkey) -> DispatchKeySet: ...
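
# DispatchKeySet behaves like a small set algebra over DispatchKey (sketch):
#
#     ks = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
#     ks = ks.add(torch._C.DispatchKey.AutogradCPU)
#     assert ks.has(torch._C.DispatchKey.CPU)
#     merged = ks | torch._C._dispatch_tls_local_include_set()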

_dispatch_autogradother_backends: DispatchKeySet
_additional_keys_to_prop_for_wrapper_tensors: DispatchKeySet

def _dispatch_has_backend_fallback(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_keyset_full_after(t: _dispatchkey) -> DispatchKeySet: ...
def _dispatch_keyset_full() -> DispatchKeySet: ...
def _dispatch_keyset_to_string(keyset: DispatchKeySet) -> str: ...
def _dispatch_get_backend_keyset_from_autograd(
    dispatch: _dispatchkey,
) -> DispatchKeySet: ...
def _dispatch_keys(tensor: Tensor) -> DispatchKeySet: ...
def _dispatch_tls_local_exclude_set() -> DispatchKeySet: ...
def _dispatch_tls_local_include_set() -> DispatchKeySet: ...
def _dispatch_is_included_in_alias(
    dispatch_a: _dispatchkey,
    dispatch_b: _dispatchkey,
) -> _bool: ...
def _propagate_xla_data(a: Tensor, b: Tensor) -> None: ...
def _replace_(a: Tensor, b: Tensor) -> None: ...
def _commit_update(a: Tensor) -> None: ...

class _ExcludeDispatchKeyGuard:
    def __init__(self, keyset: DispatchKeySet) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _IncludeDispatchKeyGuard:
    def __init__(self, k: DispatchKey) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _ForceDispatchKeyGuard:
    def __init__(self, include: DispatchKeySet, exclude: DispatchKeySet) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _PreserveDispatchKeyGuard:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _AutoDispatchBelowAutograd:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

class _AutoDispatchBelowADInplaceOrView:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
def _dispatch_get_registrations_for_dispatch_key(
    dispatch_key: str = "",
) -> list[str]: ...
def _are_functorch_transforms_active() -> _bool: ...

# Defined in torch/csrc/autograd/init.cpp
def _set_python_dispatcher(dispatcher: object) -> None: ...
def _get_nested_int(id: _int, coeff: _int) -> SymInt: ...
def _get_constant_bool_symnode(val: _bool) -> Any: ...

class _TorchDispatchModeKey(Enum):
    ${torch_dispatch_mode_key_hints}

class _SetExcludeDispatchKeyGuard:
    def __init__(self, k: DispatchKey, enabled: _bool) -> None: ...
    def __enter__(self): ...
    def __exit__(self, *exc_info: object) -> None: ...

def _get_dtensor_allow_implicit_replication() -> _bool: ...
def _set_dtensor_allow_implicit_replication(value: _bool) -> None: ...

# Defined in torch/csrc/utils/schema_info.h

class _SchemaInfo:
    def __init__(self, schema: FunctionSchema) -> None: ...
    @overload
    def is_mutable(self) -> _bool: ...
    @overload
    def is_mutable(self, name: str) -> _bool: ...
    def has_argument(self, name: str) -> _bool: ...

# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig:
    num_calling_threads: _int
    num_worker_threads: _int
    num_warmup_iters: _int
    num_iters: _int
    profiler_output_path: str

class BenchmarkExecutionStats:
    latency_avg_ms: _float
    num_iters: _int

class ThroughputBenchmark:
    def __init__(self, module: Any) -> None: ...
    def add_input(self, *args: Any, **kwargs: Any) -> None: ...
    def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
    def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...
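
# Benchmarking a module's throughput (sketch; the public wrapper is
# torch.utils.throughput_benchmark.ThroughputBenchmark, and `module` /
# `example_input` are hypothetical):
#
#     bench = torch._C.ThroughputBenchmark(module)
#     bench.add_input(example_input)
#     config = torch._C.BenchmarkConfig()
#     config.num_calling_threads = 1
#     config.num_warmup_iters = 10
#     config.num_iters = 100
#     print(bench.benchmark(config).latency_avg_ms)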

# Defined in torch/csrc/Storage.cpp
${legacy_storage_base_hints}

# TODO: where
${legacy_class_hints}

# Defined in torch/csrc/autograd/python_engine.cpp
class _ImperativeEngine:
    def queue_callback(self, callback: Callable[[], None]) -> None: ...
    def run_backward(self, *args: Any, **kwargs: Any) -> tuple[Tensor, ...]: ...
    def is_checkpoint_valid(self) -> _bool: ...

# Defined in torch/csrc/autograd/python_variable.cpp
class _TensorMeta(type): ...

${index_type_def}

# Defined in torch/csrc/autograd/python_variable.cpp
class TensorBase(metaclass=_TensorMeta):
    requires_grad: _bool
    retains_grad: _bool
    shape: Size
    data: Tensor
    names: list[str]
    device: _device
    dtype: _dtype
    grad_dtype: _dtype | None
    layout: _layout
    real: Tensor
    imag: Tensor
    T: Tensor
    H: Tensor
    mT: Tensor
    mH: Tensor
    ndim: _int
    output_nr: _int
    _version: _int
    _base: Tensor | None
    _cdata: _int
    grad_fn: _Node | None
    _grad_fn: Any
    _grad: Tensor | None
    grad: Tensor | None
    _backward_hooks: dict[_int, Callable[[Tensor], Tensor | None]] | None
    nbytes: _int
    itemsize: _int
    _has_symbolic_sizes_strides: _bool

    def _view_func_unsafe(
        self,
        new_base: Tensor,
        symint_visitor_fn: Callable[[_int], _int] | None = None,
        tensor_visitor_fn: Callable[[Tensor], Tensor] | None = None,
    ): ...
    ${tensor_method_hints}

_TensorBase = TensorBase

def _DTensor_OpSchema_post_init(self: OpSchema) -> None: ...
def _DTensor_OpSchema_recompute_comparison_key(self: OpSchema) -> None: ...

# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...
def _set_thread_name(name: str) -> None: ...
def _get_thread_name() -> str: ...

# Defined in torch/csrc/Module.cpp
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int: ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ...
def _get_accelerator(check: _bool = False) -> _device: ...
def _storage_Use_Count(storage_ptr: _int) -> _int: ...

# Defined in torch/csrc/mtia/Module.cpp
def _mtia_init() -> None: ...
def _mtia_isBuilt() -> _bool: ...
def _mtia_isInBadFork() -> _bool: ...
def _mtia_deviceSynchronize() -> None: ...
def _mtia_getCurrentStream(device: _int) -> Stream: ...
def _mtia_getCurrentRawStream(device: _int) -> _int: ...
def _mtia_setCurrentStream(stream: Stream) -> None: ...
def _mtia_getDefaultStream(device: _int) -> Stream: ...
def _mtia_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
def _mtia_memoryStats(device: _int) -> dict[str, Any]: ...
def _mtia_getDeviceCapability(device: _int) -> tuple[_int, _int]: ...
def _mtia_getDeviceProperties(device: _int) -> dict[str, Any]: ...
def _mtia_emptyCache() -> None: ...
def _mtia_recordMemoryHistory(
    enabled: str | None,
    stacks: str,
    max_entries,
) -> None: ...
def _mtia_memorySnapshot() -> dict[str, Any]: ...
def _mtia_attachOutOfMemoryObserver(
    observer: Callable[[_int, _int, _int, _int], None],
) -> None: ...
def _mtia_getDeviceCount() -> _int: ...
def _mtia_resetPeakMemoryStats(device: _int) -> None: ...

# Defined in torch/csrc/mps/Module.cpp
def _mps_deviceSynchronize() -> None: ...
def _mps_get_core_count() -> _int: ...
def _mps_get_default_generator() -> Generator: ...
def _mps_get_name() -> _str: ...
def _mps_emptyCache() -> None: ...
def _mps_setMemoryFraction(fraction: _float) -> None: ...
def _mps_currentAllocatedMemory() -> _int: ...
def _mps_driverAllocatedMemory() -> _int: ...
def _mps_recommendedMaxMemory() -> _int: ...
def _mps_is_available() -> _bool: ...
def _mps_is_on_macos_or_newer(major: _int, minor: _int) -> _bool: ...
def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ...
def _mps_profilerStopTrace() -> None: ...
def _mps_acquireEvent(enable_timing: _bool) -> _int: ...
def _mps_releaseEvent(event_id: _int) -> None: ...
def _mps_recordEvent(event_id: _int) -> None: ...
def _mps_waitForEvent(event_id: _int) -> None: ...
def _mps_synchronizeEvent(event_id: _int) -> None: ...
def _mps_queryEvent(event_id: _int) -> _bool: ...
def _mps_elapsedTimeOfEvents(start_event_id: _int, end_event_id: _int) -> _float: ...
def _mps_isCaptureEnabled() -> _bool: ...
def _mps_isCapturing() -> _bool: ...
def _mps_startCapture(name: str) -> None: ...
def _mps_stopCapture() -> None: ...

# Defined in torch/csrc/cuda/Module.cpp
def _cuda_getCurrentStream(device: _int) -> tuple: ...
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
def _cuda_getDefaultStream(device: _int) -> tuple: ...
def _cuda_getStreamFromExternal(data_ptr: _int, device_index: _int) -> tuple: ...
def _cuda_getCurrentBlasHandle() -> _int: ...
def _cuda_clearCublasWorkspaces() -> None: ...
def _cuda_setDevice(device: _int) -> None: ...
def _cuda_exchangeDevice(device: _int) -> _int: ...
def _cuda_maybeExchangeDevice(device: _int) -> _int: ...
def _cuda_getDevice() -> _int: ...
def _cuda_getDeviceCount() -> _int: ...
def _cuda_set_sync_debug_mode(warn_level: _int | str) -> None: ...
def _cuda_get_sync_debug_mode() -> _int: ...
def _cuda_sleep(cycles: _int) -> None: ...
def _cuda_synchronize() -> None: ...
def _cuda_ipc_collect() -> None: ...
def _cuda_getArchFlags() -> str | None: ...
def _cuda_init() -> None: ...
def _cuda_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
def _cuda_getCompiledVersion() -> _int: ...
def _cuda_cudaHostAllocator() -> _int: ...
def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ...
def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
def _cuda_cudaCachingAllocator_enable(val: _bool) -> None: ...
def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
def _cuda_beginAllocateToPool(device: _int, mempool_id: tuple[_int, _int]) -> None: ...
def _cuda_beginAllocateCurrentThreadToPool(
    device: _int,
    mempool_id: tuple[_int, _int],
) -> None: ...
def _cuda_endAllocateToPool(device: _int, mempool_id: tuple[_int, _int]) -> None: ...
def _cuda_beginAllocateCurrentStreamToPool(
    device: _int,
    mempool_id: tuple[_int, _int],
) -> None: ...
def _cuda_releasePool(device: _int, mempool_id: tuple[_int, _int]) -> None: ...
def _cuda_checkPoolLiveAllocations(
    device: _int,
    mempool_id: tuple[_int, _int],
    expected_live_allocations: set,
) -> _bool: ...
def _cuda_setCheckpointPoolState(
    device: _int,
    state: _cuda_CUDAAllocator_AllocatorState,
    stale_storages: list[_int],
    storages_to_add_deleters_to: list[_int],
) -> None: ...
def _cuda_getMemoryFraction(device: _int) -> _float: ...
def _cuda_setMemoryFraction(fraction: _float, device: _int) -> None: ...
def _cuda_emptyCache() -> None: ...
def _cuda_memoryStats(device: _int) -> dict[str, Any]: ...
def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ...
def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
def _cuda_hostMemoryStats() -> dict[str, Any]: ...
def _cuda_resetAccumulatedHostMemoryStats() -> None: ...
def _cuda_resetPeakHostMemoryStats() -> None: ...
def _cuda_memorySnapshot(mempool_id: tuple[_int, _int] | None) -> dict[str, Any]: ...
def _cuda_record_memory_history_legacy(
    enabled: _bool,
    record_context: _bool,
    record_context_cpp: _bool,
    alloc_trace_max_entries: _int,
    alloc_trace_record_context: _bool,
    clear_history: _bool,
    compile_context: _bool,
    global_record_annotations: _bool,
) -> None: ...
def _cuda_record_memory_history(
    enabled: str | None,
    context: str | None,
    stacks: str,
    max_entries: _int,
    clear_history: _bool,
    compile_context: _bool,
    global_record_annotations: _bool,
) -> None: ...
def _cuda_isHistoryEnabled() -> _bool: ...
def _cuda_getAllocatorBackend() -> str: ...
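
# Capturing an allocator trace (sketch; torch.cuda.memory exposes the
# public _record_memory_history/_snapshot wrappers over these bindings):
#
#     torch.cuda.memory._record_memory_history(max_entries=100000)
#     ...  # run the workload of interest
#     snapshot = torch.cuda.memory._snapshot()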

class _cuda_CUDAAllocator_AllocatorState: ...

def _cuda_getCheckpointState(
    device: _int,
    mempool: tuple[_int, _int],
) -> _cuda_CUDAAllocator_AllocatorState: ...
def _set_cached_tensors_enabled(enabled: _bool) -> None: ...
def _add_cached_tensor(t: Tensor) -> None: ...
def _remove_cached_tensor(t: Tensor) -> None: ...
def _tensors_data_ptrs_at_indices_equal(
    tensors: list[Tensor | _int],
    ptrs: list[_int | None],
    indices: list[_int],
) -> _bool: ...
def _construct_CUDA_Tensor_From_Storage_And_Metadata(
    metadata: dict,
    storage: Storage,
) -> Tensor: ...
def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ...
def _set_storage_data_ptr_access_error_msg(storage_ptr: _int, s: str) -> None: ...
def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ...
def _has_Standard_Deleter(storage_ptr: _int) -> _bool: ...

class _cuda_CUDAAllocator: ...

def _cuda_customAllocator(alloc_fn: _int, free_fn: _int) -> _cuda_CUDAAllocator: ...
def _cuda_changeCurrentAllocator(allocator: _cuda_CUDAAllocator) -> None: ...
def _cuda_getAllocator() -> _cuda_CUDAAllocator: ...
def _cuda_lock_mutex() -> None: ...
def _cuda_unlock_mutex() -> None: ...
def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...
def _cuda_jiterator_compile_and_launch_kernel(
    code_string: str,
    kernel_name: str,
    return_by_ref: _bool,
    num_outputs: _int,
    tensors: tuple,
    kwargs: dict[str, _int | _float | _bool],
) -> Tensor: ...
def _cuda_get_cudnn_benchmark_limit() -> _int: ...
def _cuda_set_cudnn_benchmark_limit(arg: _int) -> None: ...
def _cuda_get_conv_benchmark_empty_cache() -> _bool: ...
def _cudnn_set_conv_benchmark_empty_cache(enable: _bool) -> None: ...

def _nccl_version() -> _int: ...
def _nccl_version_suffix() -> bytes: ...
def _nccl_unique_id() -> bytes: ...
def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ...
def _nccl_reduce(
    input: Sequence[Tensor],
    output: Tensor,
    root: _int,
    op: _int,
    streams: Sequence[_CudaStreamBase] | None,
    comms: Sequence[object] | None,
) -> None: ...
def _nccl_all_reduce(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    op: _int,
    streams: Sequence[_CudaStreamBase] | None,
    comms: Sequence[object] | None,
) -> None: ...
def _nccl_broadcast(
    input: Sequence[Tensor],
    root: _int,
    streams: Sequence[_CudaStreamBase] | None,
    comms: Sequence[object] | None,
) -> None: ...
def _nccl_all_gather(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    streams: Sequence[_CudaStreamBase] | None,
    comms: Sequence[object] | None,
) -> None: ...
def _nccl_reduce_scatter(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    op: _int,
    streams: Sequence[_CudaStreamBase] | None,
    comms: Sequence[object] | None,
) -> None: ...
def _rocm_is_backward_pass() -> _bool: ...
def _cuda_tunableop_enable(val: _bool) -> None: ...
def _cuda_tunableop_is_enabled() -> _bool: ...
def _cuda_tunableop_tuning_enable(val: _bool) -> None: ...
def _cuda_tunableop_tuning_is_enabled() -> _bool: ...
def _cuda_tunableop_set_max_tuning_duration(duration: _int) -> None: ...
def _cuda_tunableop_get_max_tuning_duration() -> _int: ...
def _cuda_tunableop_set_max_tuning_iterations(iterations: _int) -> None: ...
def _cuda_tunableop_get_max_tuning_iterations() -> _int: ...
def _cuda_tunableop_set_filename(
    filename: str,
    insert_device_ordinal: _bool | None,
) -> None: ...
def _cuda_tunableop_get_filename() -> str: ...
def _cuda_tunableop_read_file(filename: str | None) -> _bool: ...
def _cuda_tunableop_get_results() -> tuple[str, str, str, _float]: ...
def _cuda_tunableop_get_validators() -> tuple[str, str]: ...
def _cuda_tunableop_set_rotating_buffer_size(buffer_size: _int) -> None: ...
def _cuda_tunableop_get_rotation_buffer_size() -> _int: ...
def _cuda_tunableop_set_numerical_check_tolerances(
    enabled: _bool,
    atol: _float = 1e-5,
    rtol: _float = 1e-5,
) -> None: ...

class _CudaDeviceProperties:
    name: str
    major: _int
    minor: _int
    multi_processor_count: _int
    total_memory: _int
    is_integrated: _int
    is_multi_gpu_board: _int
    max_threads_per_multi_processor: _int
    gcnArchName: str
    warp_size: _int
    uuid: str
    L2_cache_size: _int
    clock_rate: _int
    memory_clock_rate: _int
    memory_bus_width: _int

# Functions related to SDPA
class _SDPAParams:
    query: Tensor
    key: Tensor
    value: Tensor
    attn_mask: Tensor | None
    dropout: _float
    is_causal: _bool
    enable_gqa: _bool
    def __init__(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attn_mask: Tensor | None,
        dropout: _float,
        is_causal: _bool,
        enable_gqa: _bool,
    ) -> None: ...

class _SDPBackend(Enum):
    ERROR = -1
    MATH = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2
    CUDNN_ATTENTION = 3
    OVERRIDEABLE = 4

def _is_flash_attention_available() -> _bool: ...
def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
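
# Probing which SDPA backend would accept a set of inputs (sketch; q, k, v
# stand in for hypothetical query/key/value Tensors):
#
#     params = torch._C._SDPAParams(q, k, v, None, 0.0, False, False)
#     if torch._C._can_use_flash_attention(params, False):
#         ...  # flash attention is eligible for these inputs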

# Defined in torch/csrc/cuda/GdsFile.cpp
def _gds_register_buffer(t: Storage) -> None: ...
def _gds_deregister_buffer(t: Storage) -> None: ...
def _gds_register_handle(fd: _int) -> _int: ...
def _gds_deregister_handle(handle: _int) -> None: ...
def _gds_load_storage(handle: _int, s: Storage, offset: _int) -> None: ...
def _gds_save_storage(handle: _int, s: Storage, offset: _int) -> None: ...

# Defined in torch/csrc/cuda/python_comm.cpp
def _broadcast(tensor: Tensor, devices: list[_int]) -> list[Tensor]: ...
def _broadcast_out(tensor: Tensor, out_tensors: list[Tensor]) -> list[Tensor]: ...
def _broadcast_coalesced(
    tensors: list[Tensor],
    devices: list[_int],
    buffer_size: _int,
) -> list[list[Tensor]]: ...
def _scatter(
    tensor: Tensor,
    devices: list[_int],
    chunk_sizes: list[_int] | None,
    dim: _int,
    streams: list[Stream] | None,
) -> list[Tensor]: ...
def _scatter_out(
    tensor: Tensor,
    out_tensors: list[Tensor],
    dim: _int,
    streams: list[Stream] | None,
) -> list[Tensor]: ...
def _gather(
    tensors: list[Tensor],
    dim: _int,
    destination_index: _int | None,
) -> Tensor: ...
def _gather_out(tensors: list[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ...

# Defined in torch/csrc/cuda/Stream.cpp
class _CudaStreamBase(Stream):
    stream_id: _int
    device_index: _int
    device_type: _int

    device: _device
    cuda_stream: _int
    priority: _int

    def __new__(
        cls,
        priority: _int = 0,
        stream_id: _int = 0,
        device_index: _int = 0,
        stream_ptr: _int = 0,
    ) -> Self: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    def priority_range(self) -> tuple[_int, _int]: ...

# Defined in torch/csrc/cuda/Event.cpp
class _CudaEventBase:
    device: _device
    cuda_event: _int

    def __new__(
        cls,
        enable_timing: _bool = False,
        blocking: _bool = False,
        interprocess: _bool = False,
        external: _bool = False,
    ) -> Self: ...
    @classmethod
    def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> _CudaEventBase: ...
    def record(self, stream: _CudaStreamBase) -> None: ...
    def wait(self, stream: _CudaStreamBase) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: _CudaEventBase) -> _float: ...
    def synchronize(self) -> None: ...
    def ipc_handle(self) -> bytes: ...
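
# Timing a region with CUDA events (sketch; torch.cuda.Event/Stream are the
# public subclasses of these base types):
#
#     start = torch.cuda.Event(enable_timing=True)
#     end = torch.cuda.Event(enable_timing=True)
#     start.record()
#     ...  # launch kernels
#     end.record()
#     end.synchronize()
#     elapsed_ms = start.elapsed_time(end)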

# Defined in torch/csrc/cuda/Graph.cpp
class _CUDAGraph:
    def __new__(cls, keep_graph: _bool = ...) -> Self: ...
    def capture_begin(
        self,
        pool: _POOL_HANDLE | None = ...,
        capture_error_mode: str = "global",
    ) -> None: ...
    def capture_end(self) -> None: ...
    def instantiate(self) -> None: ...
    def register_generator_state(self, Generator) -> None: ...
    def replay(self) -> None: ...
    def reset(self) -> None: ...
    def pool(self) -> _POOL_HANDLE: ...
    def enable_debug_mode(self) -> None: ...
    def debug_dump(self, debug_path: str) -> None: ...
    def raw_cuda_graph(self) -> _int: ...
    def raw_cuda_graph_exec(self) -> _int: ...

# Defined in torch/csrc/cuda/MemPool.cpp
class _MemPool:
    def __init__(
        self,
        allocator: _cuda_CUDAAllocator | None = None,
        is_user_created: _bool = True,
        use_on_oom: _bool = False,
    ) -> None: ...
    @property
    def id(self) -> tuple[_int, _int]: ...
    @property
    def allocator(self) -> _cuda_CUDAAllocator | None: ...
    def use_count(self) -> _int: ...

def _cuda_isCurrentStreamCapturing() -> _bool: ...
def _graph_pool_handle() -> tuple[_int, _int]: ...
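
# Graph capture and replay (sketch; torch.cuda.CUDAGraph/torch.cuda.graph
# are the public wrappers, and `model`/`static_input` are hypothetical):
#
#     g = torch.cuda.CUDAGraph()
#     with torch.cuda.graph(g):
#         out = model(static_input)
#     g.replay()  # re-runs the captured kernels on the same buffers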

# Defined in torch/csrc/xpu/Module.cpp
def _xpu_setDevice(device: _int) -> None: ...
def _xpu_exchangeDevice(device: _int) -> _int: ...
def _xpu_maybeExchangeDevice(device: _int) -> _int: ...
def _xpu_getDevice() -> _int: ...
def _xpu_getDeviceCount() -> _int: ...
def _xpu_getArchFlags() -> str | None: ...
def _xpu_init() -> None: ...
def _xpu_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
def _xpu_getCurrentStream(device: _int) -> tuple: ...
def _xpu_getCurrentRawStream(device: _int) -> _int: ...
def _xpu_getStreamFromExternal(data_ptr: _int, device_index: _int) -> tuple: ...
def _xpu_synchronize(device: _int) -> None: ...
def _xpu_emptyCache() -> None: ...
def _xpu_memoryStats(device: _int) -> dict[str, Any]: ...
def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ...
def _xpu_resetPeakMemoryStats(device: _int) -> None: ...
def _xpu_getMemoryInfo(device: _int) -> tuple[_int, _int]: ...
def _xpu_canDeviceAccessPeer(device: _int, peer: _int) -> _bool: ...

class _XpuDeviceProperties:
    name: str
    platform_name: str
    vendor: str
    device_id: _int
    driver_version: str
    version: str
    max_compute_units: _int
    gpu_eu_count: _int
    max_work_group_size: _int
    max_num_sub_groups: _int
    sub_group_sizes: list[_int]
    has_fp16: _bool
    has_fp64: _bool
    has_atomic64: _bool
    has_bfloat16_conversions: _bool
    has_subgroup_matrix_multiply_accumulate: _bool
    has_subgroup_matrix_multiply_accumulate_tensor_float32: _bool
    has_subgroup_2d_block_io: _bool
    total_memory: _int
    gpu_subslice_count: _int
    architecture: _int
    type: str
    uuid: Any

# Defined in torch/csrc/xpu/Stream.cpp
class _XpuStreamBase(Stream):
    stream_id: _int
    device_index: _int
    device_type: _int

    device: _device
    sycl_queue: _int
    priority: _int

    def __new__(
        cls,
        priority: _int = 0,
        stream_id: _int = 0,
        device_index: _int = 0,
        device_type: _int = 0,
    ) -> Self: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    @staticmethod
    def priority_range() -> tuple: ...

# Defined in torch/csrc/xpu/Event.cpp
class _XpuEventBase:
    device: _device
    sycl_event: _int

    def __new__(cls, enable_timing: _bool = False) -> Self: ...
    def record(self, stream: _XpuStreamBase) -> None: ...  # records on a stream
    def wait(self, stream: _XpuStreamBase) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: _XpuEventBase) -> _float: ...
    def synchronize(self) -> None: ...

# Defined in torch/csrc/DataLoader.cpp
def _set_worker_signal_handlers(
    *arg: Any,
) -> None: ...  # THPModule_setWorkerSignalHandlers
def _set_worker_pids(
    key: _int,
    child_pids: tuple[_int, ...],
) -> None: ...  # THPModule_setWorkerPIDs
def _remove_worker_pids(loader_id: _int) -> None: ...  # THPModule_removeWorkerPIDs
def _error_if_any_worker_fails() -> None: ...  # THPModule_errorIfAnyWorkerFails

# Defined in torch/csrc/DeviceAccelerator.cpp
def _accelerator_getAccelerator() -> _device: ...
def _accelerator_setDeviceIndex(device_index: _int) -> None: ...
def _accelerator_getDeviceIndex() -> _int: ...
def _accelerator_setStream(Stream) -> None: ...
def _accelerator_getStream(device_index: _int) -> Stream: ...
def _accelerator_synchronizeDevice(device_index: _int) -> None: ...
def _accelerator_exchangeDevice(device_index: _int) -> _int: ...
def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ...
def _accelerator_isAllocatorInitialized() -> _bool: ...
def _accelerator_emptyCache() -> None: ...
def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ...
def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ...
def _accelerator_resetPeakStats(device_index: _int) -> None: ...

# Defined in torch/csrc/jit/python/python_tracer.cpp
class TracingState:
    def push_scope(self, scope_name: str) -> None: ...
    def pop_scope(self) -> None: ...
    def current_scope(self) -> str: ...
    def set_graph(self, graph: Graph) -> None: ...
    def graph(self) -> Graph: ...

def _create_graph_by_tracing(
    func: Callable[..., Any],
    inputs: Any,
    var_name_lookup_fn: Callable[[Tensor], str],
    strict: Any,
    force_outplace: Any,
    self: Any = None,
    argument_names: list[str] = ...,
) -> tuple[Graph, Stack]: ...
def _tracer_warn_use_python(): ...
def _get_tracing_state() -> TracingState: ...

# Defined in torch/csrc/jit/python/python_ir.cpp
# Not actually defined in python_ir.cpp, not sure where they are.
class IValue: ...

Stack: TypeAlias = list[IValue]

class JitType:
    annotation_str: str
    def isSubtypeOf(self, other: JitType) -> _bool: ...
    def with_dtype(self, dtype: _dtype) -> JitType: ...
    def with_sizes(self, sizes: list[_int | None]) -> JitType: ...
    def kind(self) -> str: ...
    def scalarType(self) -> str | None: ...
    def getElementType(self) -> JitType: ...
    def dtype(self) -> _dtype | None: ...

class InferredType:
    def __init__(self, arg: JitType | str) -> None: ...
    def type(self) -> JitType: ...
    def success(self) -> _bool: ...
    def reason(self) -> str: ...

class Type(JitType):
    def str(self) -> _str: ...
    def containedTypes(self) -> list[JitType]: ...
    def dim(self) -> _int | None: ...
    def undefined(self) -> _bool | None: ...
    def sizes(self) -> list[_int] | None: ...
    def symbol_sizes(self) -> list[_int] | None: ...
    def varyingSizes(self) -> list[_int | None] | None: ...
    def strides(self) -> list[_int] | None: ...
    def contiguous(self) -> Self: ...
    def device(self) -> _device | None: ...
    def is_interface_type(self) -> _bool: ...
    def requires_grad(self) -> _bool: ...
    @property
    def annotation_string(self) -> _str: ...

class AnyType(JitType):
    @staticmethod
    def get() -> AnyType: ...

class NoneType(JitType):
    @staticmethod
    def get() -> NoneType: ...

class BoolType(JitType):
    @staticmethod
    def get() -> BoolType: ...

class FloatType(JitType):
    @staticmethod
    def get() -> FloatType: ...

class ComplexType(JitType):
    @staticmethod
    def get() -> ComplexType: ...

class IntType(JitType):
    @staticmethod
    def get() -> IntType: ...

class SymIntType(JitType):
    @staticmethod
    def get() -> SymIntType: ...

class SymBoolType(JitType):
    @staticmethod
    def get() -> SymBoolType: ...

class NumberType(JitType):
    @staticmethod
    def get() -> NumberType: ...

class StringType(JitType):
    @staticmethod
    def get() -> StringType: ...

class DeviceObjType(JitType):
    @staticmethod
    def get() -> DeviceObjType: ...

class _GeneratorType(JitType):
    @staticmethod
    def get() -> _GeneratorType: ...

class StreamObjType(JitType):
    @staticmethod
    def get() -> StreamObjType: ...

class ListType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...
    @staticmethod
    def ofInts() -> ListType: ...
    @staticmethod
    def ofTensors() -> ListType: ...
    @staticmethod
    def ofFloats() -> ListType: ...
    @staticmethod
    def ofComplexDoubles() -> ListType: ...
    @staticmethod
    def ofBools() -> ListType: ...
    @staticmethod
    def ofStrings() -> ListType: ...

class DictType(JitType):
    def __init__(self, key: JitType, value: JitType) -> None: ...
    def getKeyType(self) -> JitType: ...
    def getValueType(self) -> JitType: ...

class TupleType(JitType):
    def __init__(self, a: list[JitType | None]) -> None: ...
    def elements(self) -> list[JitType]: ...

class UnionType(JitType):
    def __init__(self, a: list[JitType]) -> None: ...

class ClassType(JitType):
    def __init__(self, qualified_name: str) -> None: ...
    def qualified_name(self) -> str: ...

class InterfaceType(JitType):
    def __init__(self, qualified_name: str) -> None: ...
    def getMethod(self, name: str) -> FunctionSchema | None: ...
    def getMethodNames(self) -> list[str]: ...

JitTypeT = TypeVar("JitTypeT", bound=JitType)  # noqa: PYI001

class OptionalType(JitType, Generic[JitTypeT]):
    def __init__(self, a: JitTypeT) -> None: ...
    def getElementType(self) -> JitTypeT: ...
    @staticmethod
    def ofTensor() -> OptionalType: ...

class FutureType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...

class AwaitType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...

class RRefType(JitType):
    def __init__(self, a: JitType) -> None: ...

class EnumType(JitType):
    def __init__(
        self,
        qualified_name: str,
        value_type: JitType,
        enum_names_values: list[Any],
    ) -> None: ...

class TensorType(JitType):
    @classmethod
    def get(cls) -> TensorType: ...
    @classmethod
    def getInferred(cls) -> TensorType: ...
    def with_sizes(self, other: list[_int | None] | None) -> TensorType: ...
    def sizes(self) -> list[_int] | None: ...
    def varyingSizes(self) -> list[_int | None] | None: ...
    def strides(self) -> list[_int] | None: ...
    def device(self) -> _device | None: ...
    def dim(self) -> _int: ...
    def dtype(self) -> _dtype | None: ...
    @staticmethod
    def create_from_tensor(t: Tensor) -> TensorType: ...
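
# Constructing JIT types programmatically (sketch):
#
#     int_list = torch._C.ListType(torch._C.IntType.get())  # List[int]
#     assert int_list.getElementType().isSubtypeOf(torch._C.NumberType.get())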

# Defined in torch/csrc/jit/python/python_tree_views.cpp
class SourceRange: ...
class TreeView: ...

class Ident(TreeView):
    @property
    def name(self) -> str: ...

class ClassDef(TreeView): ...

class Def(TreeView):
    def name(self) -> Ident: ...

class Decl(TreeView): ...

# Defined in torch/csrc/distributed/rpc/init.cpp
def _rpc_init() -> _bool: ...

# Defined in torch/csrc/distributed/autograd/init.cpp
def _dist_autograd_init() -> _bool: ...

# Defined in torch/csrc/distributed/c10d/init.cpp
def _c10d_init() -> _bool: ...

# Defined in torch/csrc/distributed/rpc/testing/init.cpp
def _faulty_agent_init() -> _bool: ...
def _register_py_class_for_device(device: str, cls: Any) -> None: ...

# Defined in torch/csrc/Module.cpp
def _current_graph_task_id() -> _int: ...
def _current_autograd_node() -> _Node: ...
def _will_engine_execute_node(node: _Node) -> _bool: ...
def _dispatch_key_set(tensor) -> str: ...

# Defined in torch/csrc/Exceptions.cpp
class AcceleratorError(RuntimeError): ...
class OutOfMemoryError(RuntimeError): ...
class _DistError(RuntimeError): ...
class _DistBackendError(RuntimeError): ...
class _DistStoreError(RuntimeError): ...
class _DistNetworkError(RuntimeError): ...
class _DistQueueEmptyError(_DistStoreError): ...

# Defined in torch/csrc/profiler/init.cpp
class CapturedTraceback: ...

def gather_traceback(python: _bool, script: _bool, cpp: _bool) -> CapturedTraceback: ...
def symbolize_tracebacks(
    tracebacks: list[CapturedTraceback],
) -> list[dict[str, Any]]: ...
def _load_mobile_module_from_file(filename: str): ...
def _load_mobile_module_from_bytes(bytes_: bytes): ...
def _load_jit_module_from_file(filename: str): ...
def _load_jit_module_from_bytes(bytes_: bytes): ...
def _save_mobile_module(m: LiteScriptModule, filename: str): ...
def _save_jit_module(m: ScriptModule, filename: str, extra_files: dict[str, Any]): ...
def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
def _save_jit_module_to_bytes(
    m: ScriptModule,
    extra_files: dict[str, Any],
) -> bytes: ...
def _get_module_info_from_flatbuffer(data: bytes): ...
def _jit_resolve_packet(op_name: str, *args, **kwargs) -> str: ...
def _swap_tensor_impl(t1: Tensor, t2: Tensor): ...
def _pickle_save(obj: Any) -> bytes: ...
def _pickle_load_obj(bs: bytes) -> Any: ...

# Defined in torch/csrc/jit/runtime/static/init.cpp
def _jit_to_static_module(graph_or_module: Graph | ScriptModule) -> Any: ...
def _fuse_to_static_module(
    graph_or_module: Graph | ScriptModule,
    min_size: _int,
) -> Any: ...

# Defined in torch/csrc/fx/node.cpp
def _fx_map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: ...
def _fx_map_arg(a: Any, fn: Callable[[Any], Any]) -> Any: ...

class _NodeBase:
    _erased: _bool
    _prev: FxNode
    _next: FxNode
    def __init__(
        self,
        graph: Any,
        name: str,
        op: str,
        target: Any,
        return_type: Any,
    ) -> None: ...
    def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...

class _NodeIter(Iterator[FxNode]):
    def __init__(self, root: FxNode, reversed: _bool) -> None: ...
    def __iter__(self) -> Self: ...
    def __next__(self) -> FxNode: ...

# Defined in torch/csrc/inductor/static_cuda_launcher.cpp
class _StaticCudaLauncher:
    @staticmethod
    def _load_kernel(
        cubin_file: str,
        func_name: str,
        shared_mem_bytes: _int,
        device: _int,
    ) -> tuple[_int, _int, _int]: ...
    @staticmethod
    def _launch_kernel(
        func: _int,
        grid_x: _int,
        grid_y: _int,
        grid_z: _int,
        num_warps: _int,
        shared_mem_bytes: _int,
        arg_types: str,
        args: tuple[Any, ...],
        stream: _int,
    ) -> None: ...