# ${generated_comment}

import torch
from torch.package import PackageExporter
from torch import Tensor
from enum import Enum
from pathlib import Path
from typing import (
    Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List,
    NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union,
    Generic, Set, AnyStr)
from typing_extensions import Literal
from torch._six import inf

from torch.types import (
    _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage, SymInt, _dispatchkey
)
from torch.storage import TypedStorage

import builtins

# This module is defined in torch/csrc/Module.cpp

from . import _nn as _nn
from . import _onnx as _onnx
from . import _VariableFunctions as _VariableFunctions
from . import _functorch as _functorch
from . import _lazy as _lazy
from . import _lazy_ts_backend as _lazy_ts_backend

T = TypeVar('T')
S = TypeVar("S", bound="torch.Tensor")

# Defined in torch/csrc/Device.cpp
class device:
    type: str  # THPDevice_type
    index: _int  # THPDevice_index

    def __get__(self, instance, owner=None) -> device: ...

    # THPDevice_pynew
    @overload
    def __init__(self, device: Union[_device, _int, str]) -> None: ...

    @overload
    def __init__(self, type: str, index: _int) -> None: ...

    def __reduce__(self) -> Tuple[Any, ...]: ...  # THPDevice_reduce
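
# Usage sketch (comment only; illustrative, not part of the generated stub):
# a device can be built from a string or a (type, index) pair, matching the
# overloads above.
#   d = torch.device("cuda", 0)   # d.type == "cuda", d.index == 0
#   d2 = torch.device("cpu")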

# Defined in torch/csrc/Stream.cpp
class Stream:
    _cdata: _int  # Stream handle
    device: device  # The device of the stream

    ...

# Defined in torch/csrc/Size.cpp
class Size(Tuple[_int, ...]):
    # TODO: __reduce__

    @overload  # type: ignore[override]
    def __getitem__(self: Size, key: _int) -> _int: ...

    @overload
    def __getitem__(self: Size, key: slice) -> Size: ...

    def numel(self: Size) -> _int: ...

    ...
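
# Usage sketch (comment only): torch.Size is a tuple subclass, so indexing
# with an int yields an int while slicing yields another Size.
#   s = torch.randn(2, 3, 4).shape   # torch.Size([2, 3, 4])
#   s[0] == 2; s[1:] == torch.Size([3, 4]); s.numel() == 24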

# Defined in torch/csrc/Dtype.cpp
class dtype:
    # TODO: __reduce__
    is_floating_point: _bool
    is_complex: _bool
    is_signed: _bool
    ...

# Defined in torch/csrc/TypeInfo.cpp
class iinfo:
    bits: _int
    min: _int
    max: _int
    dtype: str

    def __init__(self, dtype: _dtype) -> None: ...

class finfo:
    bits: _int
    min: _float
    max: _float
    eps: _float
    tiny: _float
    smallest_normal: _float
    resolution: _float
    dtype: str

    @overload
    def __init__(self, dtype: _dtype) -> None: ...

    @overload
    def __init__(self) -> None: ...
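
# Usage sketch (comment only): iinfo/finfo report the numeric limits of a
# dtype; the no-argument finfo() overload uses the current default dtype.
#   torch.finfo(torch.float32).eps   # ~1.19e-07
#   torch.iinfo(torch.int8).max      # 127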

${dtype_class_hints}

# Defined in torch/csrc/Layout.cpp
class layout:
    ...

# Defined in torch/csrc/utils/disable_torch_function.cpp
def DisableTorchFunction(): ...

# Defined in torch/csrc/utils/tensor_layouts.cpp
strided: layout = ...
sparse_coo: layout = ...
sparse_csr: layout = ...
sparse_csc: layout = ...
sparse_bsr: layout = ...
sparse_bsc: layout = ...
_mkldnn: layout = ...

# Defined in torch/csrc/MemoryFormat.cpp
class memory_format: ...

# Defined in torch/csrc/utils/tensor_memoryformats.cpp
contiguous_format: memory_format = ...
channels_last: memory_format = ...
channels_last_3d: memory_format = ...
preserve_format: memory_format = ...
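
# Usage sketch (comment only): memory formats control the physical layout of
# a tensor without changing its logical shape.
#   x = torch.randn(8, 3, 32, 32).to(memory_format=torch.channels_last)
#   x.is_contiguous(memory_format=torch.channels_last)  # True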

# Defined in torch/csrc/QScheme.cpp
class qscheme: ...

# Defined in torch/csrc/utils/tensor_qschemes.h
per_tensor_affine: qscheme = ...
per_channel_affine: qscheme = ...
per_tensor_symmetric: qscheme = ...
per_channel_symmetric: qscheme = ...
per_channel_affine_float_qparams: qscheme = ...

# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase(object):
    ...

# Defined in torch/csrc/autograd/python_legacy_variable.cpp
class _LegacyVariableBase(object):
    def __init__(
        self,
        data: Optional[Tensor] = ...,
        requires_grad: Optional[_bool] = ...,
        volatile: Optional[_bool] = ...,
        _grad_fn: Optional[_FunctionBase] = ...
    ) -> None: ...

# Defined in torch/csrc/jit/python/init.cpp
class IODescriptor: ...

class JITException: ...

class Future(object):
    def __init__(self, devices: List[device]) -> None: ...
    def done(self) -> _bool: ...
    def value(self) -> Any: ...
    def wait(self) -> Any: ...
    def add_done_callback(self, callback: Callable) -> None: ...
    def then(self, callback: Callable) -> Future: ...
    def set_result(self, result: Any) -> None: ...
    def _set_unwrap_func(self, callback: Callable) -> None: ...

def _jit_set_num_profiled_runs(num: _size) -> _size: ...

# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h
class MobileOptimizerType:
    ...

CONV_BN_FUSION: MobileOptimizerType
INSERT_FOLD_PREPACK_OPS: MobileOptimizerType
REMOVE_DROPOUT: MobileOptimizerType
FUSE_ADD_RELU: MobileOptimizerType
HOIST_CONV_PACKED_PARAMS: MobileOptimizerType

def fork(*args: Any, **kwargs: Any) -> Future: ...
def wait(fut: Future) -> Any: ...
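
# Usage sketch (comment only): fork/wait back the public torch.jit.fork and
# torch.jit.wait, which run a callable asynchronously and join on its Future.
#   fut = torch.jit.fork(torch.neg, torch.ones(2))
#   torch.jit.wait(fut)   # tensor([-1., -1.])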
def _collect_all(futures: List[Future]) -> Future: ...
def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ...

def unify_type_list(types: List[JitType]) -> JitType: ...
def _freeze_module(module: ScriptModule,
                   preserved_attrs: List[str] = [],
                   freeze_interfaces: _bool = True,
                   preserveParameters: _bool = True) -> ScriptModule: ...
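
# Usage sketch (comment only): _freeze_module backs the public
# torch.jit.freeze, which inlines attributes/parameters into the graph of an
# eval-mode scripted module. MyModule below is a stand-in name.
#   m = torch.jit.script(MyModule().eval())
#   frozen = torch.jit.freeze(m, preserved_attrs=["version"])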
def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ...
def _jit_pass_optimize_for_inference(module: 'torch.jit.ScriptModule',
                                     other_methods: List[str] = []) -> None: ...
def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ...
def _jit_pass_concat_frozen_linear(graph: Graph): ...
def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ...
def _jit_pass_transpose_frozen_linear(graph: Graph): ...
def _jit_pass_remove_dropout(module: 'torch.jit.ScriptModule'): ...

def _is_tracing() -> _bool: ...
def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ...
def _get_operation_overload(op_name: str, op_overload_name: str) -> Tuple[Callable, Callable, List[Any]]: ...
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
                                  optimization_blocklist: Set[MobileOptimizerType],
                                  preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
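
# Usage sketch (comment only): this pass backs the public
# torch.utils.mobile_optimizer.optimize_for_mobile; optimization_blocklist
# lets callers skip individual MobileOptimizerType passes.
#   from torch.utils.mobile_optimizer import optimize_for_mobile
#   opt = optimize_for_mobile(scripted,   # scripted: a stand-in ScriptModule
#                             optimization_blocklist={torch._C.CONV_BN_FUSION})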
def _clone_module_with_class(module: 'torch.jit.ScriptModule',
                             ignored_methods: List[AnyStr],
                             ignored_attributes: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule',
                                         optimization_blocklist: Set[MobileOptimizerType],
                                         preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule',
                                        preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
def _jit_pass_inline(Graph) -> None: ...
def _jit_pass_constant_propagation(Graph) -> None: ...
def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ...
def _jit_register_decomposition_for_schema(schema: FunctionSchema, Graph) -> None: ...
def _jit_erase_non_input_shape_information(Graph) -> None: ...
def _jit_get_schemas_for_operator(name: str) -> List[FunctionSchema]: ...
def _jit_get_all_schemas() -> List[FunctionSchema]: ...
def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ...
def _jit_can_fuse_on_cpu() -> _bool: ...
def _jit_can_fuse_on_gpu() -> _bool: ...
def _jit_can_fuse_on_cpu_legacy() -> _bool: ...
def _debug_get_fusion_group_inlining() -> _bool: ...
def _debug_set_fusion_group_inlining(enable: _bool): ...
def _jit_texpr_fuser_enabled() -> _bool: ...
def _jit_nvfuser_enabled() -> _bool: ...
def _jit_llga_enabled() -> _bool: ...
def _jit_set_llga_enabled(enable: _bool): ...
def _llvm_enabled() -> _bool: ...
def _jit_override_can_fuse_on_cpu(override: _bool): ...
def _jit_override_can_fuse_on_gpu(override: _bool): ...
def _jit_override_can_fuse_on_cpu_legacy(override: _bool): ...
def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
def _jit_set_texpr_fuser_enabled(enable: _bool): ...
def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ...
def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ...
def _jit_cat_wo_conditionals(optimize_cat: _bool): ...
def _jit_opt_conditionals(opt_conds: _bool): ...
def _jit_pass_canonicalize(graph: Graph, keep_unique_names: _bool = True): ...
def _jit_pass_erase_shape_information(graph: Graph): ...
def _jit_pass_fold_convbn(module: 'torch.jit.ScriptModule'): ...
def _jit_pass_insert_observers(module: 'torch.jit.ScriptModule',
                               method_name: str,
                               qconfig_dict: Dict[str, Any],
                               inplace: _bool,
                               quant_type: _int): ...
def _jit_pass_insert_quant_dequant(module: 'torch.jit.ScriptModule',
                                   method_name: str,
                                   inplace: _bool,
                                   debug: _bool,
                                   quant_type: _int): ...
def _jit_pass_insert_quant_dequant_for_ondevice_ptq(module: 'torch.jit.ScriptModule',
                                                    method_name: str,
                                                    inplace: _bool,
                                                    debug: _bool,
                                                    quant_type: _int): ...
def _jit_pass_quant_finalize(module: 'torch.jit.ScriptModule',
                             quant_type: _int,
                             preserved_attrs: Sequence[str]): ...
def _jit_pass_quant_finalize_for_ondevice_ptq(module: 'torch.jit.ScriptModule',
                                              quant_type: _int,
                                              method_name: str): ...
def _jit_pass_insert_observer_method_for_ondevice_ptq(module: 'torch.jit.ScriptModule',
                                                      method_name: str,
                                                      qconfig_dict: Dict[str, Any],
                                                      inplace: _bool,
                                                      quant_type: _int): ...
def _jit_set_profiling_executor(profiling_flag: _bool) -> _bool: ...
def _jit_set_profiling_mode(profiling_flag: _bool) -> _bool: ...
def _jit_set_fusion_strategy(strategy: List[Tuple[str, _int]]) -> List[Tuple[str, _int]]: ...
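
# Usage sketch (comment only): _jit_set_fusion_strategy backs the public
# torch.jit.set_fusion_strategy, which takes (kind, depth) pairs where kind is
# "STATIC" or "DYNAMIC".
#   torch.jit.set_fusion_strategy([("STATIC", 2), ("DYNAMIC", 10)])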
def _jit_try_infer_type(obj: Any) -> InferredType: ...
def _jit_get_trigger_value(trigger_name: str) -> _int: ...

# Defined in torch/csrc/jit/python/script_init.cpp
ResolutionCallback = Callable[[str], Callable[..., Any]]

# Defined in torch/csrc/jit/python/script_init.cpp
# and torch/csrc/jit/python/init.cpp
def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ...
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
def _jit_assert_is_instance(obj: Any, type: JitType): ...
def _jit_clear_class_registry() -> None: ...
def _jit_set_emit_hooks(ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]) -> None: ...
def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ...
def _load_for_lite_interpreter_from_buffer(buffer: BinaryIO, map_location: Union[_device, str, None]): ...
def _export_operator_list(module: LiteScriptModule): ...
def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
def _backport_for_mobile(filename_input: Union[str, Path], filename_output: Union[str, Path], to_version: _int) -> None: ...
def _backport_for_mobile_from_buffer(buffer: BinaryIO, filename_output: Union[str, Path], to_version: _int) -> None: ...
def _backport_for_mobile_to_buffer(filename_input: Union[str, Path], to_version: _int) -> bytes: ...
def _backport_for_mobile_from_buffer_to_buffer(buffer: BinaryIO, to_version: _int) -> bytes: ...
def _get_model_ops_and_info(filename: Union[str, Path]): ...
def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ...
def _get_mobile_model_contained_types(filename: Union[str, Path]): ...
def _get_mobile_model_contained_types_from_buffer(buffer: BinaryIO): ...
def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ...
def _get_graph_executor_optimize(optimize: Optional[_bool] = None) -> _bool: ...
def _set_graph_executor_optimize(optimize: _bool): ...
def _export_opnames(module: ScriptModule) -> List[str]: ...
def _create_function_from_trace(
    qualname: str,
    func: Callable[..., Any],
    input_tuple: Tuple[Any, ...],
    var_lookup_fn: Callable[[Tensor], str],
    strict: _bool,
    force_outplace: _bool,
    argument_names: List[str]
) -> Tuple[Graph, Stack]: ...
def _create_function_from_trace_with_dict(
    qualname: str,
    func: Callable[..., Any],
    input_dict: Dict[str, Any],
    var_lookup_fn: Callable[[Tensor], str],
    strict: _bool,
    force_outplace: _bool,
    argument_names: List[str]
) -> Tuple[Graph, Stack]: ...
def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
def parse_type_comment(comment: str) -> Decl: ...
def _get_upgraders_map_size() -> _int: ...
def _dump_upgraders_map() -> Dict[str, str]: ...
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ...
def merge_type_from_type_comment(decl: Decl, type_annotation_decl: Decl, is_method: _bool) -> Decl: ...
def parse_ir(input: str, parse_tensor_constants: _bool) -> Graph: ...
def parse_schema(schema: str) -> FunctionSchema: ...
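
# Usage sketch (comment only): parse_schema turns a schema string into a
# FunctionSchema whose arguments/returns can be inspected.
#   s = torch._C.parse_schema(
#       "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
#   s.name, [a.name for a in s.arguments]   # "aten::add", ["self", "other", "alpha"]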
def get_device(input: Tensor) -> _int: ...

def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallback) -> JitType: ...
def _create_module_with_type(ty: JitType) -> ScriptModule: ...
def _create_object_with_type(ty: ClassType) -> ScriptObject: ...
def _run_emit_module_hook(m: ScriptModule): ...
def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def, new_name: str) -> Def: ...

def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ...
def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ...
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool, is_script: _bool) -> None: ...
def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph, module: Optional[ScriptModule] = None) -> None: ...
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
def _jit_pass_peephole(graph: Graph, disable_shape_peepholes: _bool = False) -> None: ...
def _jit_pass_onnx_autograd_function_process(graph: Graph) -> None: ...
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
def _jit_pass_onnx_unpack_quantized_weights(
    graph: Graph,
    paramsDict: Dict[str, IValue],
    caffe2: _bool
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_quantization_insert_permutes(
    graph: Graph,
    paramsDict: Dict[str, IValue]
) -> Dict[str, IValue]: ...
def _jit_pass_custom_pattern_based_rewrite_graph(pattern: str, fused_node_name: str, graph: Graph) -> None: ...
def _jit_onnx_list_model_parameters(module: ScriptModule) -> Tuple[ScriptModule, List[IValue]]: ...
def _jit_pass_erase_number_types(graph: Graph) -> None: ...
def _jit_pass_onnx_lint(graph: Graph) -> None: ...
def _jit_pass_onnx(graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes) -> Graph: ...
def _jit_pass_onnx_scalar_type_analysis(graph: Graph, lowprecision_cast: _bool, opset_version: _int) -> None: ...
def _jit_pass_onnx_peephole(graph: Graph, opset_version: _int, fixed_batch_size: _bool) -> None: ...
def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
def _jit_pass_onnx_function_extraction(graph: Graph, module_names: Set[str], param_names: List[str]) -> Dict[Node, Dict[str, str]]: ...
def _jit_pass_onnx_clear_scope_records() -> None: ...
def _jit_pass_onnx_track_scope_attributes(graph: Graph, onnx_attrs: Dict[str, Any]) -> None: ...
def _jit_is_onnx_log_enabled() -> _bool: ...
def _jit_set_onnx_log_enabled(enabled: _bool) -> None: ...
def _jit_set_onnx_log_output_stream(stream_name: str) -> None: ...
def _jit_onnx_log(*args: Any) -> None: ...
def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
def _jit_pass_onnx_deduplicate_initializers(graph: Graph, params_dict: Dict[str, IValue], is_train: _bool) -> Dict[str, IValue]: ...
def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ...
def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ...
def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
def _jit_pass_onnx_node_shape_type_inference(n: Node, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ...
def _jit_onnx_convert_pattern_from_subblock(block: Block, n: Node, env: Dict[Value, Value]) -> List[Value]: ...
def _jit_pass_onnx_block(
    old_block: Block,
    new_block: Block,
    operator_export_type: _onnx.OperatorExportTypes,
    env: Dict[Value, Value],
    is_sub_block: _bool
) -> Dict[Value, Value]: ...
def _jit_pass_onnx_assign_scoped_names_for_node_and_value(graph: Graph) -> None: ...
def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> List[Value]: ...
def _jit_onnx_create_full_scope_name(class_name: str, variable_name: str) -> str: ...

def _compile_graph_to_code_table(name: str, graph: Graph) -> IValue: ...

def _generate_upgraders_graph() -> Dict[str, Graph]: ...

def _calculate_package_version_based_on_upgraders(val: _bool): ...

def _get_version_calculator_flag() -> _bool: ...

def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ...
def _jit_script_compile_overload(
    qualname: str,
    overload_decl: Decl,
    implementation_def: Def,
    rcb: ResolutionCallback,
    implementation_defaults: Dict[str, Any],
    signature: Any
): ...
def _jit_script_compile(
    qual_name: str,
    definition: Def,
    rcb: ResolutionCallback,
    defaults: Dict[str, Any]
): ...
def _jit_script_class_compile(
    qual_name: str,
    definition: ClassDef,
    defaults: Dict[str, Dict[str, Any]],
    rcb: ResolutionCallback
): ...
def _parse_source_def(src: str) -> Def: ...
def import_ir_module(
    cu: CompilationUnit,
    filename: Union[str, Path],
    map_location: Union[_device, str, None],
    extra_files: Dict[str, Any]
) -> ScriptModule: ...
def import_ir_module_from_buffer(
    cu: CompilationUnit,
    buffer: BinaryIO,
    map_location: Union[_device, str, None],
    extra_files: Dict[str, Any]
) -> ScriptModule: ...
def _import_ir_module_from_package(
    cu: CompilationUnit,
    reader: PyTorchFileReader,
    storage_context: DeserializationStorageContext,
    map_location: Union[_device, str, None],
    ts_id: str
) -> ScriptModule: ...

def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ...
def _check_onnx_proto(proto: str, full_check: _bool = False) -> None: ...
def _propagate_and_assign_input_shapes(
    graph: Graph,
    inputs: Tuple[Tensor, ...],
    param_count_list: List[_int],
    with_grad: _bool,
    propagate: _bool
) -> Graph: ...

# Defined in torch/csrc/jit/runtime/graph_executor.h
class GraphExecutorState:
    ...

# Defined in torch/csrc/jit/ir/alias_analysis.h
class AliasDb:
    def __str__(self) -> str: ...
    ...

class _InsertPoint:
    def __enter__(self) -> None: ...
    def __exit__(self, *args) -> None: ...

# Defined in torch/csrc/jit/ir/ir.h
class Use:
    @property
    def user(self) -> Node: ...
    @property
    def offset(self) -> _int: ...
    def isAfter(self, other: Use) -> _bool: ...
    ...

# Defined in torch/csrc/jit/ir/ir.h
class Value:
    def type(self) -> JitType: ...
    def setType(self, t: JitType) -> Value: ...
    def setTypeAs(self, other: Value) -> Value: ...
    def inferTypeFrom(self, t: Tensor) -> None: ...
    def debugName(self) -> str: ...
    def setDebugName(self, name: str) -> None: ...
    def unique(self) -> _int: ...
    def offset(self) -> _int: ...
    def node(self) -> Node: ...
    def uses(self) -> List[Use]: ...
    def replaceAllUsesWith(self, val: Value) -> None: ...
    def replaceAllUsesAfterNodeWith(self, node: Node, val: Value) -> None: ...
    def requires_grad(self) -> _bool: ...
    def requiresGrad(self) -> _bool: ...
    def copyMetadata(self, other: Value) -> Value: ...
    def isCompleteTensor(self) -> _bool: ...
    def toIValue(self) -> IValue: ...
    ...

# Defined in torch/csrc/jit/ir/ir.h
class Block:
    def inputs(self) -> Iterator[Value]: ...
    def outputs(self) -> Iterator[Value]: ...
    def nodes(self) -> Iterator[Node]: ...
    def paramNode(self) -> Node: ...
    def returnNode(self) -> Node: ...
    def owningNode(self) -> Node: ...
    def registerOutput(self, n: Value) -> _int: ...
    def addNode(self, name: str, inputs: Sequence[Value]) -> Node: ...
    ...

# Defined in torch/csrc/jit/ir/ir.h
class Node:
    def __getitem__(self, key: str) -> Any: ...
    def schema(self) -> str: ...
    def input(self) -> Value: ...
    def inputs(self) -> Iterator[Value]: ...
    def inputsAt(self, idx: _int) -> Value: ...
    def inputsSize(self) -> _int: ...
    def output(self) -> Value: ...
    def outputs(self) -> Iterator[Value]: ...
    def outputsAt(self, idx: _int) -> Value: ...
    def outputsSize(self) -> _int: ...
    def hasMultipleOutputs(self) -> _bool: ...
    def blocks(self) -> List[Block]: ...
    def addBlock(self) -> Block: ...
    def mustBeNone(self) -> _bool: ...
    def matches(self, pattern: str) -> _bool: ...
    def kind(self) -> str: ...
    def kindOf(self, name: str) -> str: ...
    def addInput(self, name: str) -> Value: ...
    def replaceInput(self, i: _int, newValue: Value) -> Value: ...
    def replaceInputWith(self, from_: Value, to: Value) -> None: ...
    def replaceAllUsesWith(self, n: Node) -> None: ...
    def insertBefore(self, n: Node) -> Node: ...
    def insertAfter(self, n: Node) -> Node: ...
    def isBefore(self, n: Node) -> _bool: ...
    def isAfter(self, n: Node) -> _bool: ...
    def moveBefore(self, n: Node) -> None: ...
    def moveAfter(self, n: Node) -> None: ...
    def removeInput(self, i: _int) -> None: ...
    def removeAllInputs(self, i: _int) -> None: ...
    def hasUses(self) -> _bool: ...
    def eraseOutput(self, i: _int) -> None: ...
    def addOutput(self) -> Value: ...
    def scopeName(self) -> str: ...
    def isNondeterministic(self) -> _bool: ...
    def copyAttributes(self, rhs: Node) -> Node: ...
    def copyMetadata(self, rhs: Node) -> Node: ...
    def hasAttributes(self) -> _bool: ...
    def hasAttribute(self, name: str) -> _bool: ...
    def removeAttribute(self, attr: str) -> Node: ...
    def namedInput(self, name: str) -> Value: ...
    def sourceRange(self) -> SourceRange: ...
    def owningBlock(self) -> Block: ...
    def findNode(self, kind: str, recurse: _bool = True) -> Node: ...
    def findAllNodes(self, kind: str, recurse: _bool = True) -> List[Node]: ...
    def getModuleHierarchy(self) -> str: ...
    def prev(self) -> Node: ...
    def destroy(self) -> None: ...
    def attributeNames(self) -> List[str]: ...

    # Accessors for attributes as types.
    def f(self, name: str) -> _float: ...
    def f_(self, name: str, val: _float) -> Node: ...
    def fs(self, name: str) -> List[_float]: ...
    def fs_(self, name: str, val: List[_float]) -> Node: ...
    def c(self, name: str) -> complex: ...
    def c_(self, name: str, val: complex) -> Node: ...
    def s(self, name: str) -> str: ...
    def s_(self, name: str, val: str) -> Node: ...
    def ss(self, name: str) -> List[str]: ...
    def ss_(self, name: str, val: List[str]) -> Node: ...
    def i(self, name: str) -> _int: ...
    def i_(self, name: str, val: _int) -> Node: ...
    # Cannot define "is" like this because it's a reserved keyword in python.
    # def is(self, name: str) -> List[_int]: ...
    # def is_(self, name: str, val: List[_int]) -> Node: ...
    def g(self, name: str) -> Graph: ...
    def g_(self, name: str, val: Graph) -> Node: ...
    def gs(self, name: str) -> List[Graph]: ...
    def gs_(self, name: str, val: List[Graph]) -> Node: ...
    def ival(self, name: str) -> IValue: ...
    def ival_(self, name: str, val: IValue) -> Node: ...
    def t(self, name: str) -> Tensor: ...
    def t_(self, name: str, val: Tensor) -> Node: ...
    def ts(self, name: str) -> List[Tensor]: ...
    def ts_(self, name: str, val: List[Tensor]) -> Node: ...
    def ty(self, name: str) -> JitType: ...
    def ty_(self, name: str, val: JitType) -> Node: ...
    def tys(self, name: str) -> List[JitType]: ...
    def tys_(self, name: str, val: List[JitType]) -> Node: ...
    ...

# Defined in torch/csrc/jit/ir/ir.h
class Graph:
    def inputs(self) -> Iterator[Value]: ...
    def outputs(self) -> Iterator[Value]: ...
    def nodes(self) -> Iterator[Node]: ...
    def param_node(self) -> Node: ...
    def return_node(self) -> Node: ...
    def addInput(self, name: str) -> Value: ...
    def eraseInput(self, i: _int) -> None: ...
    def registerOutput(self, n: Value) -> _int: ...
    def eraseOutput(self, i: _int) -> None: ...
    def create(self, name: str, args, num_outputs: _int) -> Node: ...
    def appendNode(self, n: Node) -> Node: ...
    def prependNode(self, n: Node) -> Node: ...
    def insertNode(self, n: Node) -> Node: ...
    def block(self) -> Block: ...
    def lint(self) -> None: ...
    def alias_db(self) -> AliasDb: ...
    def setInsertPoint(self, n: Union[Block, Node]) -> None: ...
    def insert_point_guard(self, n: Union[Block, Node]) -> _InsertPoint: ...
    def insertPoint(self) -> Node: ...
    def insertGraph(self, callee: Graph, inputs: List[Value]) -> List[Value]: ...
    def makeMultiOutputIntoTuple(self) -> None: ...
    ...
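
# Usage sketch (comment only): Graph/Node/Value mirror the TorchScript IR and
# can be inspected from any scripted function.
#   @torch.jit.script
#   def f(x: torch.Tensor) -> torch.Tensor:
#       return x + 1
#   for n in f.graph.nodes():
#       print(n.kind())   # e.g. "aten::add"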

# Defined in aten/src/ATen/core/alias_info.h
class AliasInfo:
    is_write: _bool
    before_set: Set[str]
    after_set: Set[str]

# Defined in aten/src/ATen/core/function_schema.h
class Argument:
    name: str
    type: JitType
    default_value: Optional[Any]
    def has_default_value(self) -> _bool: ...
    kwarg_only: _bool
    is_out: _bool
    alias_info: Optional[AliasInfo]
    ...

class FunctionSchema:
    arguments: List[Argument]
    returns: List[Argument]
    name: str
    overload_name: str
    ...

class _UpgraderEntry:
    bumped_at_version: _int
    upgrader_name: str
    old_schema: str
    def __init__(self, bumped_at_version: _int, upgrader_name: str, old_schema: str) -> None: ...

class _UpgraderRange:
    min_version: _int
    max_version: _int

def _get_max_operator_version() -> _int: ...

def _get_operator_version_map() -> Dict[str, List[_UpgraderEntry]]: ...

def _get_upgrader_ranges(name: str) -> List[_UpgraderRange]: ...

def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> None: ...

def _test_only_remove_entry_to_op_version(op_name: str) -> None: ...

# Defined in torch/csrc/jit/python/script_init.cpp
class ScriptModuleSerializer(object):
    def __init__(self, export_writer: PyTorchFileWriter) -> None: ...
    def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ...
    def write_files(self) -> None: ...
    def storage_context(self) -> SerializationStorageContext: ...
    ...

# Defined in torch/csrc/jit/python/script_init.cpp
class SerializationStorageContext(object):
    def __init__(self) -> None: ...
    def has_storage(self, storage: Storage) -> _bool: ...
    def get_or_add_storage(self, storage: Storage) -> _int: ...
    ...

# Defined in torch/csrc/jit/python/script_init.cpp
class DeserializationStorageContext(object):
    def __init__(self) -> None: ...
    def get_storage(self, name: str, dtype: _dtype) -> Tensor: ...
    def has_storage(self, name: str) -> _bool: ...
    def add_storage(self, name: str, tensor: Tensor) -> _int: ...
    ...

# Defined in torch/csrc/jit/python/script_init.cpp
class ConcreteModuleTypeBuilder:
    def __init__(self, obj: Any) -> None: ...
    def set_module_dict(self): ...
    def set_module_list(self): ...
    def set_parameter_list(self): ...
    def set_parameter_dict(self): ...
    def add_attribute(self, name: str, ty: JitType, is_param: _bool, is_buffer: _bool): ...
    def add_module(self, name: str, meta: ConcreteModuleType): ...
    def add_constant(self, name: str, value: Any): ...
    def add_overload(self, method_name: str, overloaded_method_names: List[str]): ...
    def add_builtin_function(self, name: str, symbol_name: str): ...
    def add_failed_attribute(self, name: str, failure_reason: str): ...
    def add_function_attribute(self, name: str, ty: JitType, func: Callable[..., Any]): ...
    def add_ignored_attribute(self, name: str): ...
    def add_ignored_attributes(self, names: List[str]): ...
    def add_forward_hook(self, hook: Callable[..., Any]): ...
    def add_forward_pre_hook(self, pre_hook: Callable[..., Any]): ...

class ConcreteModuleType:
    def get_constants(self) -> Dict[str, Any]: ...
    def equals(self, other: 'ConcreteModuleType') -> _bool: ...

    @staticmethod
    def from_jit_type(ty: JitType) -> ConcreteModuleType: ...

class CallStack:
    def __init__(self, name: str, range: SourceRange): ...

class ErrorReport:
    def __init__(self, range: SourceRange) -> None: ...
    def what(self) -> str: ...

    @staticmethod
    def call_stack() -> str: ...

class CompilationUnit:
    def __init__(self, lang: str = ..., _frames_up: _int = ...) -> None: ...
    def find_function(self, name: str) -> ScriptFunction: ...
    def __getattr__(self, name: str) -> ScriptFunction: ...
    def define(self, script: str, rcb: ResolutionCallback = ..., _frames_up: _int = ...): ...
    def get_interface(self, name: str) -> InterfaceType: ...
    def get_functions(self) -> List[ScriptFunction]: ...
    def create_function(self, name: str, graph: Graph, shouldMangle: _bool = ...) -> ScriptFunction: ...
    def get_class(self, name: str) -> ClassType: ...
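
# Usage sketch (comment only): CompilationUnit backs torch.jit.CompilationUnit,
# which compiles TorchScript source strings into callable ScriptFunctions.
#   cu = torch.jit.CompilationUnit('def f(x):\n    return x + 1')
#   cu.f(torch.ones(2))   # tensor([2., 2.])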

class ScriptObject:
    def setattr(self, name: str, value: Any): ...

class ScriptModule(ScriptObject):
    def _method_names(self) -> List[str]: ...
    def _get_method(self, name: str) -> ScriptMethod: ...

class LiteScriptModule:
    def __call__(self, *input): ...
    def find_method(self, method_name: str): ...
    def forward(self, *input) -> List[str]: ...
    def run_method(self, method_name: str, *input): ...

class ScriptFunction:
    def __call__(self, *args, **kwargs) -> Tensor: ...
    def save(self, filename: str, _extra_files: Dict[str, bytes]) -> None: ...
    def save_to_buffer(self, _extra_files: Dict[str, bytes]) -> bytes: ...
    @property
    def graph(self) -> Graph: ...
    def inlined_graph(self) -> Graph: ...
    def schema(self) -> FunctionSchema: ...
    def code(self) -> str: ...
    def name(self) -> str: ...
    @property
    def qualified_name(self) -> str: ...

class ScriptMethod:
    graph: Graph
    @property
    def owner(self) -> ScriptModule: ...
    @property
    def name(self) -> str: ...

class ModuleDict:
    def __init__(self, mod: ScriptModule) -> None: ...
    def items(self) -> List[Tuple[str, Any]]: ...

class ParameterDict:
    def __init__(self, mod: ScriptModule) -> None: ...

class BufferDict:
    def __init__(self, mod: ScriptModule) -> None: ...

# Defined in torch/csrc/jit/api/module.h
class Module:
    ...

# Defined in torch/csrc/Module.cpp
def _initExtension(shm_manager_path: str) -> None: ...  # THPModule_initExtension
def _autograd_init() -> _bool: ...  # THPAutograd_initExtension
def _add_docstr(obj: T, doc_obj: str) -> T: ...  # THPModule_addDocStr
def _init_names(arg: Sequence[Type]) -> None: ...  # THPModule_initNames
def _has_distributed() -> _bool: ...  # THPModule_hasDistributed
def _set_default_tensor_type(type) -> None: ...  # THPModule_setDefaultTensorType
def _set_default_dtype(d: _dtype) -> None: ...  # THPModule_setDefaultDtype
def _infer_size(arg1: Size, arg2: Size) -> Size: ...  # THPModule_inferSize
def _crash_if_csrc_asan() -> _int: ...  # THPModule_crashIfCsrcASAN
def _crash_if_csrc_ubsan() -> _int: ...  # THPModule_crashIfCsrcUBSAN
def _crash_if_aten_asan() -> _int: ...  # THPModule_crashIfATenASAN
def _show_config() -> str: ...  # THPModule_showConfig
def _cxx_flags() -> str: ...  # THPModule_cxxFlags
def _parallel_info() -> str: ...  # THPModule_parallelInfo
def _set_backcompat_broadcast_warn(arg: _bool) -> None: ...  # THPModule_setBackcompatBroadcastWarn
def _get_backcompat_broadcast_warn() -> _bool: ...  # THPModule_getBackcompatBroadcastWarn
def _set_backcompat_keepdim_warn(arg: _bool) -> None: ...  # THPModule_setBackcompatKeepdimWarn
def _get_backcompat_keepdim_warn() -> _bool: ...  # THPModule_getBackcompatKeepdimWarn
def get_num_threads() -> _int: ...  # THPModule_getNumThreads
def set_num_threads(nthreads: _int) -> None: ...  # THPModule_setNumThreads
def get_num_interop_threads() -> _int: ...  # THPModule_getNumInteropThreads
def set_num_interop_threads(nthreads: _int) -> None: ...  # THPModule_setNumInteropThreads
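
# Usage sketch (comment only): these back the public torch.get_num_threads /
# torch.set_num_threads and their interop counterparts for CPU parallelism.
#   torch.set_num_threads(4)
#   torch.get_num_threads()   # 4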
def _get_cudnn_enabled() -> _bool: ...  # THPModule_userEnabledCuDNN
def _set_cudnn_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledCuDNN
def _get_flash_sdp_enabled() -> _bool: ...  # THPModule_userEnabledFusedSDP
def _set_sdp_use_flash(arg: _bool) -> None: ...  # THPModule_setSDPUseFlash
def _get_mem_efficient_sdp_enabled() -> _bool: ...  # THPModule_userEnabledMathSDP
def _set_sdp_use_mem_efficient(arg: _bool) -> None: ...  # THPModule_setSDPUseMemEfficient
def _get_math_sdp_enabled() -> _bool: ...  # THPModule_userEnabledMathSDP
def _set_sdp_use_math(arg: _bool) -> None: ...  # THPModule_setSDPUseMath
def _get_mkldnn_enabled() -> _bool: ...  # THPModule_userEnabledMkldnn
def _set_mkldnn_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledMkldnn
def _get_cudnn_benchmark() -> _bool: ...  # THPModule_benchmarkCuDNN
def _set_cudnn_benchmark(arg: _bool) -> None: ...  # THPModule_setBenchmarkCuDNN
def _get_cudnn_deterministic() -> _bool: ...  # THPModule_deterministicCuDNN
def _set_cudnn_deterministic(arg: _bool) -> None: ...  # THPModule_setDeterministicCuDNN
def _get_deterministic_algorithms() -> _bool: ...  # THPModule_deterministicAlgorithms
def _get_deterministic_algorithms_warn_only() -> _bool: ...  # THPModule_deterministicAlgorithmsWarnOnly
def _set_deterministic_algorithms(mode: _bool, *, warn_only: _bool = ...) -> None: ...  # THPModule_setDeterministicAlgorithms
def _get_warnAlways() -> _bool: ...  # THPModule_warnAlways
def _set_warnAlways(arg: _bool) -> None: ...  # THPModule_setWarnAlways
def _get_cudnn_allow_tf32() -> _bool: ...  # THPModule_allowTF32CuDNN
def _set_cudnn_allow_tf32(arg: _bool) -> None: ...  # THPModule_setAllowTF32CuDNN
def _get_cublas_allow_tf32() -> _bool: ...  # THPModule_allowTF32CuBLAS
def _set_cublas_allow_tf32(arg: _bool) -> None: ...  # THPModule_setAllowTF32CuBLAS
def _get_float32_matmul_precision() -> str: ...  # THPModule_float32MatmulPrecision
def _set_float32_matmul_precision(arg: str) -> None: ...  # THPModule_setFloat32MatmulPrecision
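
# Usage sketch (comment only): these back the public
# torch.set_float32_matmul_precision / torch.get_float32_matmul_precision and
# the TF32 toggles behind torch.backends.cuda.matmul.allow_tf32 and
# torch.backends.cudnn.allow_tf32.
#   torch.set_float32_matmul_precision("high")
#   torch.backends.cuda.matmul.allow_tf32 = True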
def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ...  # THPModule_allowFP16ReductionCuBLAS
def _set_cublas_allow_fp16_reduced_precision_reduction(arg: _bool) -> None: ...  # THPModule_setAllowFP16ReductionCuBLAS
def _set_conj(x: Tensor, conj: _bool) -> None: ...
def _set_neg(x: Tensor, neg: _bool) -> None: ...
def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...
def _meta_in_tls_dispatch_include() -> _bool: ...
def _select_conv_backend(*args, **kwargs) -> ConvBackend: ...
def _conv_determine_backend_memory_format(input: Tensor, weight: Tensor, backend: ConvBackend) -> memory_format: ...
def _has_storage(x: Tensor) -> _bool: ...
def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
# NB: There is no Capsule type in typing, see
# https://code.activestate.com/lists/python-dev/139675/
def _to_dlpack(data: Tensor) -> Any: ...  # THPModule_toDLPack
def _from_dlpack(data: Any) -> Tensor: ...  # THPModule_fromDLPack
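
# Usage sketch (comment only): these back torch.utils.dlpack for zero-copy
# tensor exchange via the DLPack protocol.
#   capsule = torch.utils.dlpack.to_dlpack(torch.arange(3))
#   t = torch.utils.dlpack.from_dlpack(capsule)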
def _get_cpp_backtrace(frames_to_skip: _int, maximum_number_of_frames: _int) -> str: ...  # THPModule_getCppBacktrace
def set_flush_denormal(arg: _bool) -> _bool: ...  # THPModule_setFlushDenormal
def get_default_dtype() -> _dtype: ...  # THPModule_getDefaultDtype
def _get_default_device() -> str: ...  # THPModule_getDefaultDevice
def _get_qengine() -> _int: ...  # THPModule_qEngine
def _set_qengine(qengine: _int) -> None: ...  # THPModule_setQEngine
def _supported_qengines() -> List[_int]: ...  # THPModule_supportedQEngines
def _is_xnnpack_enabled() -> _bool: ...  # THPModule_isEnabledXNNPACK
def _set_default_mobile_cpu_allocator() -> None: ...  # THPModule_setDefaultMobileCPUAllocator
def _unset_default_mobile_cpu_allocator() -> None: ...  # THPModule_unsetDefaultMobileCPUAllocator
def _is_torch_function_enabled() -> _bool: ...  # THPModule_isEnabledTorchFunction
def _has_torch_function(args: Iterable[Any]) -> _bool: ...  # THPModule_has_torch_function
def _has_torch_function_unary(Any) -> _bool: ...  # THPModule_has_torch_function_unary
def _has_torch_function_variadic(*args: Any) -> _bool: ...  # THPModule_has_torch_function_variadic
def _vmapmode_increment_nesting() -> _int: ...  # THPModule_vmapmode_increment_nesting
def _vmapmode_decrement_nesting() -> _int: ...  # THPModule_vmapmode_decrement_nesting
def _log_api_usage_once(str) -> None: ...  # LogAPIUsageOnceFromPython
def _demangle(str) -> str: ...  # c10::demangle
def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ...  # THPModule_disable_torch_function
def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ...  # THPModule_disable_dispatch_function
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
def _is_mps_available() -> _bool: ...

class _LinalgBackend:
    Default: _LinalgBackend
    Cusolver: _LinalgBackend
    Magma: _LinalgBackend

class ConvBackend(Enum):
    ...

# Defined in `valgrind.h` and `callgrind.h` respectively.
def _valgrind_supported_platform() -> _bool: ...  # NVALGRIND
def _valgrind_toggle() -> None: ...  # CALLGRIND_TOGGLE_COLLECT
def _valgrind_toggle_and_dump_stats() -> None: ...  # CALLGRIND_TOGGLE_COLLECT and CALLGRIND_DUMP_STATS

has_openmp: _bool
has_mkl: _bool
has_mps: _bool
has_lapack: _bool
has_cuda: _bool
has_mkldnn: _bool
has_cudnn: _bool
has_spectral: _bool
_GLIBCXX_USE_CXX11_ABI: _bool
default_generator: Generator

# Defined in torch/csrc/autograd/init.cpp
def _set_grad_enabled(enabled: _bool) -> None: ...
def is_grad_enabled() -> _bool: ...
def is_inference_mode_enabled() -> _bool: ...
def set_autocast_enabled(enabled: _bool) -> None: ...
def is_autocast_enabled() -> _bool: ...
def clear_autocast_cache() -> None: ...
def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
def is_autocast_cpu_enabled() -> _bool: ...
def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
def get_autocast_cpu_dtype() -> _dtype: ...
def get_autocast_gpu_dtype() -> _dtype: ...
def autocast_increment_nesting() -> _int: ...
def autocast_decrement_nesting() -> _int: ...
def is_autocast_cache_enabled() -> _bool: ...
def set_autocast_cache_enabled(enabled: _bool) -> None: ...
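
# Usage sketch (comment only): these back the public torch.autocast context
# manager and torch.is_autocast_enabled. model/x below are stand-in names.
#   with torch.autocast("cuda", dtype=torch.float16):
#       y = model(x)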
def set_anomaly_enabled(enabled: _bool, check_nan: _bool = True) -> None: ...
def is_anomaly_enabled() -> _bool: ...
def is_anomaly_check_nan_enabled() -> _bool: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ...
def __set_forward_AD_enabled(enabled: _bool) -> None: ...
def __is_forward_AD_enabled() -> _bool: ...
def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _reset_default_hooks() -> None: ...

def _is_torch_function_mode_enabled() -> _bool: ...
def _set_torch_function_mode(cls: Any) -> None: ...
def _push_on_torch_function_stack(cls: Any) -> None: ...
def _pop_torch_function_stack() -> Any: ...
def _get_function_stack_at(idx: _int) -> Any: ...
def _len_torch_function_stack() -> _int: ...

def _set_torch_dispatch_mode(cls: Any) -> None: ...
def _push_on_torch_dispatch_stack(cls: Any) -> None: ...
def _pop_torch_dispatch_stack() -> Any: ...
def _get_dispatch_stack_at(idx: _int) -> Any: ...
def _len_torch_dispatch_stack() -> _int: ...
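
# Usage sketch (comment only): the mode stacks back torch.overrides and the
# private torch.utils._python_dispatch module; a TorchDispatchMode entered as
# a context manager is pushed onto this stack. LoggingMode is a stand-in.
#   from torch.utils._python_dispatch import TorchDispatchMode
#   class LoggingMode(TorchDispatchMode):
#       def __torch_dispatch__(self, func, types, args=(), kwargs=None):
#           print(func)
#           return func(*args, **(kwargs or {}))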

class _InferenceMode(object):
    def __init__(self, mode: _bool) -> None: ...

class _DisableFuncTorch:
    def __init__(self) -> None: ...

class _EnableTorchFunction:
    def __init__(self) -> None: ...

class _MultithreadingEnabled:
    def __init__(self, mode: _bool) -> None: ...

# Defined in torch/csrc/jit/python/script_init.cpp
class LoggerBase(object):
    ...

class NoopLogger(LoggerBase):
    ...

class LockingLogger(LoggerBase):
    ...

class AggregationType(Enum):
    SUM = 0
    AVG = 1

class FileCheck(object):
    # TODO: add more FileCheck signatures
    def check_source_highlighted(self, highlight: str) -> 'FileCheck': ...
    def run(self, test_string: str) -> None: ...
    def check(self, test_string: str) -> 'FileCheck': ...
    def check_not(self, test_string: str) -> 'FileCheck': ...
    ...
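
# Usage sketch (comment only): FileCheck is the pattern-matching helper used
# throughout the JIT test suite to assert on graph dumps; graph_str below is a
# stand-in for str(some_graph).
#   torch._C.FileCheck().check("aten::add").check_not("aten::mul").run(graph_str)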

# Defined in torch/csrc/jit/python/init.cpp
class PyTorchFileReader(object):
    @overload
    def __init__(self, name: str) -> None: ...
    @overload
    def __init__(self, buffer: BinaryIO) -> None: ...
    def get_record(self, name: str) -> bytes: ...
    ...

class PyTorchFileWriter(object):
    @overload
    def __init__(self, name: str) -> None: ...
    @overload
    def __init__(self, buffer: BinaryIO) -> None: ...
    def write_record(self, name: str, data: Union[bytes, _int], size: _int) -> None: ...
    def write_end_of_file(self) -> None: ...
    def set_min_version(self, version: _int) -> None: ...
    def get_all_written_records(self) -> List[str]: ...
    def archive_name(self) -> str: ...
    ...

def _jit_get_inline_everything_mode() -> _bool: ...
def _jit_set_inline_everything_mode(enabled: _bool) -> None: ...
def _jit_get_logging_option() -> str: ...
def _jit_set_logging_option(option: str) -> None: ...
def _jit_set_logging_stream(stream_name: str) -> None: ...
def _jit_pass_cse(Graph) -> _bool: ...
def _jit_pass_dce(Graph) -> None: ...
def _jit_pass_lint(Graph) -> None: ...

# Defined in torch/csrc/jit/python/python_custom_class.cpp
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...

# Defined in torch/csrc/Module.cpp
def _rename_privateuse1_backend(backend: str) -> None: ...

# Defined in torch/csrc/Generator.cpp
class Generator(object):
    device: _device
    def __init__(self, device: Union[_device, str, None] = None) -> None: ...
    def get_state(self) -> Tensor: ...
    def set_state(self, _new_state: Tensor) -> Generator: ...
    def manual_seed(self, seed: _int) -> Generator: ...
    def seed(self) -> _int: ...
    def initial_seed(self) -> _int: ...
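
# Usage sketch (comment only): Generator is the public torch.Generator, useful
# for reproducible randomness scoped to one generator instead of the global seed.
#   g = torch.Generator().manual_seed(42)
#   torch.randn(2, generator=g)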

# Defined in torch/csrc/utils/python_dispatch.cpp

class _DispatchOperatorHandle:
    def schema(self) -> FunctionSchema: ...

class _DispatchModule:
    def def_(self, schema: str, alias: str = "") -> _DispatchModule: ...
    def def_legacy(self, schema: str) -> _DispatchModule: ...
    def def_name_t_t(self, name: str, dispatch: str, debug: str = "default_def_name_t_t") -> _DispatchModule: ...
    def def_schema_t_t(self, schema: str, dispatch: str, alias: str, debug: str = "default_def_schema_t_t") -> _DispatchModule: ...
    def impl_t_t(self, name: str, dispatch: str, debug: str = "impl_t_t") -> _DispatchModule: ...
    def impl_tt_t(self, name: str, dispatch: str, debug: str = "impl_tt_t") -> _DispatchModule: ...
    def impl(self, name: str, dispatch: str, func: Callable) -> _DispatchModule: ...
    def define(self, schema: str, alias: str = "") -> _DispatchModule: ...
    def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ...

def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> _DispatchModule: ...
def _dispatch_dump(name: str) -> str: ...
def _dispatch_dump_table(name: str) -> str: ...
def _dispatch_check_invariants(name: str) -> None: ...
def _dispatch_check_all_invariants() -> None: ...
def _dispatch_has_kernel(name: str) -> _bool: ...
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: _dispatchkey) -> _bool: ...
def _dispatch_has_kernel_for_any_dispatch_key(name: str, dispatch_key_set: DispatchKeySet) -> _bool: ...
def _dispatch_has_computed_kernel_for_dispatch_key(name: str, dispatch: _dispatchkey) -> _bool: ...
def _dispatch_find_dangling_impls() -> List[str]: ...
def _dispatch_get_all_op_names() -> List[str]: ...
def _dispatch_tls_set_dispatch_key_excluded(dispatch: _dispatchkey, val: _bool) -> None: ...
def _dispatch_tls_is_dispatch_key_excluded(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ...
def _dispatch_key_name(dispatch: _dispatchkey) -> str: ...
def _dispatch_key_parse(dispatch: _dispatchkey) -> DispatchKey: ...
def _dispatch_num_backends() -> _int: ...

class DispatchKey(Enum):
    ${dispatch_key_hints}

class DispatchKeySet:
    def __or__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __sub__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def highestPriorityTypeId(self) -> DispatchKey: ...
    def has(self, k: _dispatchkey) -> _bool: ...
    def __repr__(self) -> str: ...

_dispatch_autogradother_backends: DispatchKeySet

def _dispatch_has_backend_fallback(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_keyset_full_after(t: _dispatchkey) -> DispatchKeySet: ...
def _dispatch_keyset_to_string(keyset: DispatchKeySet) -> str: ...
def _dispatch_get_backend_keyset_from_autograd(dispatch: _dispatchkey) -> DispatchKeySet: ...
def _dispatch_keys(tensor: Tensor) -> DispatchKeySet: ...
def _dispatch_tls_local_exclude_set() -> DispatchKeySet: ...
def _dispatch_tls_local_include_set() -> DispatchKeySet: ...
def _dispatch_is_included_in_alias(dispatch_a: _dispatchkey, dispatch_b: _dispatchkey) -> _bool: ...

class ExcludeDispatchKeyGuard:
    pass

class _AutoDispatchBelowAutograd:
    pass

def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
def _dispatch_get_registrations_for_dispatch_key(dispatch_key: str = "") -> List[str]: ...

def _are_functorch_transforms_active() -> _bool: ...

# Defined in torch/csrc/autograd/init.cpp
class _DisablePythonDispatcher(object):
    pass

class _EnablePythonDispatcher(object):
    pass

def _set_python_dispatcher(dispatcher: object) -> None: ...


# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig(object):
    num_calling_threads: _int
    num_worker_threads: _int
    num_warmup_iters: _int
    num_iters: _int
    profiler_output_path: str

class BenchmarkExecutionStats(object):
    latency_avg_ms: _float
    num_iters: _int

class ThroughputBenchmark(object):
    def __init__(self, module: Any) -> None: ...
    def add_input(self, *args: Any, **kwargs: Any) -> None: ...
    def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
    def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...
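
# Usage sketch (comment only): this is wrapped by
# torch.utils.throughput_benchmark.ThroughputBenchmark for multi-threaded
# inference benchmarking; scripted_module below is a stand-in.
#   from torch.utils.throughput_benchmark import ThroughputBenchmark
#   bench = ThroughputBenchmark(scripted_module)
#   bench.add_input(torch.randn(1, 3))
#   stats = bench.benchmark(num_calling_threads=4, num_warmup_iters=10, num_iters=100)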
|
|
|
|
# Defined in torch/csrc/Storage.cpp
|
|
${legacy_storage_base_hints}
|
|
|
|
# TODO: where
|
|
${legacy_class_hints}
|
|
|
|
# Defined in torch/csrc/autograd/python_engine.cpp
|
|
class _ImperativeEngine:
|
|
...
|
|
|
|
# Defined in torch/csrc/autograd/python_variable.cpp
|
|
class _TensorMeta(type):
|
|
pass
|
|
|
|
# Defined in torch/csrc/autograd/python_variable.cpp
|
|
class _TensorBase(metaclass=_TensorMeta):
|
|
requires_grad: _bool
|
|
shape: Size
|
|
data: Tensor
|
|
names: List[str]
|
|
device: _device
|
|
dtype: _dtype
|
|
layout: _layout
|
|
real: Tensor
|
|
imag: Tensor
|
|
T: Tensor
|
|
H: Tensor
|
|
mT: Tensor
|
|
mH: Tensor
|
|
ndim: _int
|
|
output_nr: _int
|
|
_version: _int
|
|
_base: Optional[Tensor]
|
|
_cdata: _int
|
|
grad_fn: Any
|
|
_grad_fn: Any
|
|
_grad: Optional[Tensor]
|
|
grad: Optional[Tensor]
|
|
_backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
|
|
${tensor_method_hints}
|
|
|
|
# Defined in torch/csrc/multiprocessing/init.cpp
|
|
def _multiprocessing_init() -> None: ...
|
|
|
|
# Defined in torch/csrc/cuda/Module.cpp
|
|
def _cuda_getCurrentStream(device: _int) -> _int: ...
|
|
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
|
|
def _cuda_getDefaultStream(device: _int) -> _int: ...
|
|
def _cuda_getCurrentBlasHandle() -> _int: ...
|
|
def _cuda_clearCublasWorkspaces() -> None: ...
|
|
def _cuda_setDevice(device: _int) -> None: ...
|
|
def _cuda_getDevice() -> _int: ...
|
|
def _cuda_getDeviceCount() -> _int: ...
|
|
def _cuda_set_sync_debug_mode(warn_level: Union[_int, str]) -> None: ...
|
|
def _cuda_get_sync_debug_mode() -> _int: ...
|
|
def _cuda_sleep(cycles: _int) -> None: ...
|
|
def _cuda_synchronize() -> None: ...
|
|
def _cuda_ipc_collect() -> None: ...
|
|
def _cuda_getArchFlags() -> Optional[str]: ...
|
|
def _cuda_init() -> None: ...
|
|
def _cuda_setStream(cuda_stream: _int) -> None: ...
|
|
def _cuda_getCompiledVersion() -> _int: ...
|
|
def _cuda_cudaHostAllocator() -> _int: ...
|
|
def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ...
|
|
def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
|
|
def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
|
|
def _cuda_setMemoryFraction(fraction: _float, device: _int) -> None: ...
|
|
def _cuda_emptyCache() -> None: ...
|
|
def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ...
|
|
def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ...
|
|
def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
|
|
def _cuda_memorySnapshot() -> Dict[str, Any]: ...
|
|
def _cuda_recordMemoryHistory(enabled: _bool, record_context: _bool, record_context_cpp: _bool, alloc_trace_max_entries: _int, alloc_trace_record_context: _bool) -> None: ...
|
|
def _cuda_getAllocatorBackend() -> str: ...
|
|
def _cuda_lock_mutex() -> None: ...
|
|
def _cuda_unlock_mutex() -> None: ...
|
|
def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...
|
|
def _cuda_jiterator_compile_and_launch_kernel(code_string: str,
|
|
kernel_name: str,
|
|
return_by_ref: _bool,
|
|
num_outputs: _int,
|
|
tensors: Tuple,
|
|
kwargs: Dict[str, Union[_int, _float, _bool]]) -> Tensor: ...
|
|
def _cuda_get_cudnn_benchmark_limit() -> _int: ...
|
|
def _cuda_set_cudnn_benchmark_limit(arg: _int) -> None: ...
|
|
def _nccl_version() -> _int: ...
|
|
def _nccl_unique_id() -> bytes: ...
|
|
def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ...
|
|
def _nccl_reduce(input: Sequence[Tensor],
|
|
output: Tensor,
|
|
root: _int,
|
|
op: _int,
|
|
streams: Optional[Sequence[_CudaStreamBase]],
|
|
comms: Optional[Sequence[object]]) -> None: ...
|
|
def _nccl_all_reduce(input: Sequence[Tensor],
|
|
output: Sequence[Tensor],
|
|
op: _int,
|
|
streams: Optional[Sequence[_CudaStreamBase]],
|
|
comms: Optional[Sequence[object]]) -> None: ...
|
|
def _nccl_broadcast(input: Sequence[Tensor],
|
|
root: _int,
|
|
streams: Optional[Sequence[_CudaStreamBase]],
|
|
comms: Optional[Sequence[object]]) -> None: ...
|
|
def _nccl_all_gather(input: Sequence[Tensor],
|
|
output: Sequence[Tensor],
|
|
streams: Optional[Sequence[_CudaStreamBase]],
|
|
comms: Optional[Sequence[object]]) -> None: ...
|
|
def _nccl_reduce_scatter(input: Sequence[Tensor],
|
|
output: Sequence[Tensor],
|
|
op: _int,
|
|
streams: Optional[Sequence[_CudaStreamBase]],
|
|
comms: Optional[Sequence[object]]) -> None: ...
|
|
def _rocm_is_backward_pass() -> _bool: ...
|
|
|
|
|
|
class _CudaDeviceProperties:
|
|
name: str
|
|
major: _int
|
|
minor: _int
|
|
multi_processor_count: _int
|
|
total_memory: _int
|
|
is_integrated: _int
|
|
is_multi_gpu_board: _int
|
|
|
|
# Defined in torch/csrc/cuda/python_comm.cpp
|
|
def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ...
|
|
def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ...
|
|
def _broadcast_coalesced(
|
|
tensors: List[Tensor],
|
|
devices: List[_int],
|
|
buffer_size: _int
|
|
) -> List[List[Tensor]]: ...
|
|
|
|
def _scatter(tensor: Tensor, devices: List[_int], chunk_sizes: Optional[List[_int]], dim: _int, streams: Optional[List[Stream]]) -> List[Tensor]: ...
|
|
def _scatter_out(tensor: Tensor, out_tensors: List[Tensor], dim: _int, streams: Optional[List[Stream]]) -> List[Tensor]: ...
|
|
def _gather(tensors: List[Tensor], dim: _int, destination_index: Optional[_int]) -> Tensor: ...
|
|
def _gather_out(tensors: List[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ...
|
|
|
|
# Defined in torch/csrc/cuda/Stream.cpp
class _CudaStreamBase:
    _cdata: _int
    device: _device
    cuda_stream: _int
    priority: _int

    def __new__(cls, priority: _int = 0, _cdata: _int = 0, stream_ptr: _int = 0) -> _CudaStreamBase: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    def priority_range(self) -> Tuple[_int, _int]: ...
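# A hedged usage sketch: torch.cuda.Stream subclasses _CudaStreamBase, and work
# is enqueued on a non-default stream via the torch.cuda.stream context manager:
#   s = torch.cuda.Stream(priority=-1)  # higher priority than the default 0
#   with torch.cuda.stream(s):
#       y = x.mul_(2)
#   s.synchronize()  # block the host until the stream's work is done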
# Defined in torch/csrc/cuda/Event.cpp
class _CudaEventBase:
    device: _device
    cuda_event: _int

    def __new__(cls, enable_timing: _bool = False, blocking: _bool = False, interprocess: _bool = False) -> _CudaEventBase: ...
    @classmethod
    def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> _CudaEventBase: ...
    def record(self, stream: _CudaStreamBase) -> None: ...
    def wait(self, stream: _CudaStreamBase) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: _CudaEventBase) -> _float: ...
    def synchronize(self) -> None: ...
    def ipc_handle(self) -> bytes: ...
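# A hedged usage sketch: the canonical GPU timing pattern through the public
# torch.cuda.Event wrapper (which subclasses _CudaEventBase):
#   start = torch.cuda.Event(enable_timing=True)
#   end = torch.cuda.Event(enable_timing=True)
#   start.record()
#   run_workload()                # hypothetical GPU workload
#   end.record()
#   end.synchronize()             # wait for the events to complete
#   ms = start.elapsed_time(end)  # elapsed milliseconds between the events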
# Defined in torch/csrc/cuda/Graph.cpp
class _CUDAGraph:
    def capture_begin(self,
                      pool: Optional[Tuple[_int, _int]] = ...) -> None: ...
    def capture_end(self) -> None: ...
    def replay(self) -> None: ...
    def reset(self) -> None: ...
    def pool(self) -> Tuple[_int, _int]: ...

def _cuda_isCurrentStreamCapturing() -> _bool: ...

def _graph_pool_handle() -> Tuple[_int, _int]: ...
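# A hedged usage sketch: capture/replay is normally driven through the public
# torch.cuda.CUDAGraph wrapper and the torch.cuda.graph context manager, which
# pair capture_begin/capture_end around the captured region:
#   g = torch.cuda.CUDAGraph()
#   static_in = torch.randn(8, device="cuda")
#   with torch.cuda.graph(g):
#       static_out = static_in * 2
#   static_in.copy_(torch.ones(8, device="cuda"))
#   g.replay()  # re-runs the captured kernels against static_in/static_out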
# Defined in torch/csrc/DataLoader.cpp
def _set_worker_signal_handlers(*arg: Any) -> None: ...  # THPModule_setWorkerSignalHandlers
def _set_worker_pids(key: _int, child_pids: Tuple[_int, ...]) -> None: ...  # THPModule_setWorkerPIDs
def _remove_worker_pids(loader_id: _int) -> None: ...  # THPModule_removeWorkerPIDs
def _error_if_any_worker_fails() -> None: ...  # THPModule_errorIfAnyWorkerFails

# Defined in torch/csrc/jit/python/python_tracer.cpp
class TracingState:
    def push_scope(self, scope_name: str) -> None: ...
    def pop_scope(self) -> None: ...
    def current_scope(self) -> str: ...
    def set_graph(self, graph: Graph) -> None: ...
    def graph(self) -> Graph: ...
    ...

def _create_graph_by_tracing(
    func: Callable[..., Any],
    inputs: Any,
    var_name_lookup_fn: Callable[[Tensor], str],
    strict: Any,
    force_outplace: Any,
    self: Any = None,
    argument_names: List[str] = []
) -> Tuple[Graph, Stack]: ...
def _tracer_warn_use_python() -> None: ...
def _get_tracing_state() -> TracingState: ...
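# A hedged usage sketch: torch.jit.trace is the public entry point that ends up
# in _create_graph_by_tracing; the recorded IR is then visible on the result:
#   def f(x: Tensor) -> Tensor:
#       return x.relu() + 1
#   traced = torch.jit.trace(f, torch.randn(3))
#   print(traced.graph)  # the Graph built during tracing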
# Defined in torch/csrc/jit/python/python_ir.cpp
# Not actually defined in python_ir.cpp, not sure where they are.
class IValue:
    ...
Stack = List[IValue]

class JitType:
    annotation_str: str
    def isSubtypeOf(self, other: JitType) -> _bool: ...
    def with_dtype(self, dtype: _dtype) -> JitType: ...
    def with_sizes(self, sizes: List[Optional[_int]]) -> JitType: ...
    def kind(self) -> str: ...
    def scalarType(self) -> Optional[str]: ...

class InferredType:
    def __init__(self, arg: Union[JitType, str]): ...
    def type(self) -> JitType: ...
    def success(self) -> _bool: ...
    def reason(self) -> str: ...

R = TypeVar('R', bound=JitType)
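# A hedged usage sketch: JIT types can be constructed and compared from Python
# through torch._C, e.g.
#   t = torch._C.TensorType.get()
#   assert t.isSubtypeOf(torch._C.AnyType.get())  # AnyType is the top type
#   print(t.kind())  # e.g. 'TensorType'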
class AnyType(JitType):
    @staticmethod
    def get() -> AnyType: ...

class NoneType(JitType):
    @staticmethod
    def get() -> NoneType: ...

class BoolType(JitType):
    @staticmethod
    def get() -> BoolType: ...

class FloatType(JitType):
    @staticmethod
    def get() -> FloatType: ...

class ComplexType(JitType):
    @staticmethod
    def get() -> ComplexType: ...

class IntType(JitType):
    @staticmethod
    def get() -> IntType: ...

class NumberType(JitType):
    @staticmethod
    def get() -> NumberType: ...

class StringType(JitType):
    @staticmethod
    def get() -> StringType: ...

class DeviceObjType(JitType):
    @staticmethod
    def get() -> DeviceObjType: ...

class StreamObjType(JitType):
    @staticmethod
    def get() -> StreamObjType: ...

class ListType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...

    @staticmethod
    def ofInts() -> ListType: ...
    @staticmethod
    def ofTensors() -> ListType: ...
    @staticmethod
    def ofFloats() -> ListType: ...
    @staticmethod
    def ofComplexDoubles() -> ListType: ...
    @staticmethod
    def ofBools() -> ListType: ...
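# A hedged usage sketch: the ofXxx() factories are shorthands for the generic
# constructor, so these two should denote the same JIT type:
#   a = torch._C.ListType(torch._C.IntType.get())
#   b = torch._C.ListType.ofInts()
#   assert a.annotation_str == b.annotation_str  # expected 'List[int]'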
class DictType(JitType):
    def __init__(self, key: JitType, value: JitType) -> None: ...
    def getKeyType(self) -> JitType: ...
    def getValueType(self) -> JitType: ...

class TupleType(JitType):
    def __init__(self, a: List[Optional[JitType]]) -> None: ...
    def elements(self) -> List[JitType]: ...

class UnionType(JitType):
    def __init__(self, a: List[JitType]) -> None: ...

class ClassType(JitType):
    def __init__(self, qualified_name: str) -> None: ...

class InterfaceType(JitType):
    def __init__(self, qualified_name: str) -> None: ...
    def getMethod(self, name: str) -> Optional[FunctionSchema]: ...
    def getMethodNames(self) -> List[str]: ...

class OptionalType(JitType, Generic[R]):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...

    @staticmethod
    def ofTensor() -> OptionalType: ...

class FutureType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...

class RRefType(JitType):
    def __init__(self, a: JitType) -> None: ...

class EnumType(JitType):
    def __init__(
        self,
        qualified_name: str,
        value_type: JitType,
        enum_names_values: List[Any]
    ) -> None:
        ...

class TensorType(JitType):
    @classmethod
    def get(cls) -> TensorType: ...
    @classmethod
    def getInferred(cls) -> TensorType: ...
    def with_sizes(self, other: Optional[List[Optional[_int]]]) -> TensorType: ...
    def sizes(self) -> Optional[List[_int]]: ...
    def varyingSizes(self) -> Optional[List[Optional[_int]]]: ...
    def strides(self) -> Optional[List[_int]]: ...
    def device(self) -> Optional[_device]: ...
    def dim(self) -> _int: ...
    def dtype(self) -> Optional[_dtype]: ...
    @staticmethod
    def create_from_tensor(t: Tensor) -> TensorType: ...
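# A hedged usage sketch: a complete TensorType can be derived from a concrete
# tensor and then queried for its static metadata:
#   tt = torch._C.TensorType.create_from_tensor(torch.randn(2, 3))
#   tt.sizes()  # expected [2, 3]
#   tt.dtype()  # expected torch.float32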
# Defined in torch/csrc/jit/python/python_tree_views.cpp
class SourceRange:
    ...

class TreeView:
    ...

class Ident(TreeView):
    @property
    def name(self) -> str: ...

class ClassDef(TreeView):
    ...

class Def(TreeView):
    def name(self) -> Ident: ...

class Decl(TreeView):
    ...

# Defined in torch/csrc/distributed/rpc/init.cpp
def _rpc_init() -> _bool: ...

# Defined in torch/csrc/distributed/autograd/init.cpp
def _dist_autograd_init() -> _bool: ...

# Defined in torch/csrc/distributed/c10d/init.cpp
def _c10d_init() -> _bool: ...

# Defined in torch/csrc/distributed/rpc/testing/init.cpp
def _faulty_agent_init() -> _bool: ...

def _enable_minidumps(directory: str) -> None: ...
def _disable_minidumps() -> None: ...
def _enable_minidumps_on_exceptions() -> None: ...
def _register_py_class_for_device(device: str, cls: Any) -> None: ...
def _activate_cuda_trace() -> None: ...

# Defined in torch/csrc/Module.cpp
def _current_graph_task_id() -> _int: ...

class _OutOfMemoryError(RuntimeError):
    pass