mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Delete torch/__init__.pyi, deferring to direct extension stubs (#38157)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38157 This removes the error prone process of assembling `torch/__init__.pyi` (and frequently forgetting to expose things), since now we can simply rely on the true source file to get things done. Most of the old codegen in gen_pyi.py is now rerouted to various files: - `torch/_C/__init__.pyi` (the dumping pile of all misc bindings) - `torch/_C/_nn.pyi` (NN function bindings) - `torch/_C/_VariableFunctions.pyi` (torch function bindings) `torch.types` grew a bunch more definitions that previously where defined in `torch/__init__.pyi` Some miscellaneous changes - Fixed a bug where we treat single TensorList argument as implying varargs are accepted. This is actually only supported on IntList. This means we can correctly generate a stub for dequantize. - Add missing manual stub for nonzero - Switched torch/onnx/operators.py to directly refer to _C module, since apparently mypy doesn't think that methods prefixed with underscores get reexported. This may be a recurring theme; maybe we need to find a better way to solve it. Because I was really lazy, I dumped namedtuple definitions in both `torch._C` and `torch._C._VariableFunctions`. This is definitely wrong. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Differential Revision: D21497400 Pulled By: ezyang fbshipit-source-id: 07b126141c82efaca37be27c07255cb2b9b3f064
This commit is contained in:
committed by
Facebook GitHub Bot
parent
6f396e18c3
commit
6edf340338
3
.gitignore
vendored
3
.gitignore
vendored
@ -48,6 +48,9 @@ third_party/build/
|
||||
tools/shared/_utils_internal.py
|
||||
torch.egg-info/
|
||||
torch/__init__.pyi
|
||||
torch/_C/__init__.pyi
|
||||
torch/_C/_nn.pyi
|
||||
torch/_C/_VariableFunctions.pyi
|
||||
torch/nn/functional.pyi
|
||||
torch/nn/modules/*.pyi
|
||||
torch/csrc/autograd/generated/*
|
||||
|
3
mypy.ini
3
mypy.ini
@ -295,6 +295,9 @@ ignore_errors = True
|
||||
[mypy-torch.nn.intrinsic.qat.modules.conv_fused]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.onnx.operators]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.onnx.symbolic_opset8]
|
||||
ignore_errors = True
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -758,7 +758,7 @@ if __name__ == '__main__':
|
||||
'py.typed',
|
||||
'bin/*',
|
||||
'test/*',
|
||||
'__init__.pyi',
|
||||
'_C/*.pyi',
|
||||
'cuda/*.pyi',
|
||||
'optim/*.pyi',
|
||||
'autograd/*.pyi',
|
||||
|
@ -103,7 +103,6 @@ blacklist = [
|
||||
'div_out',
|
||||
'true_divide', 'true_divide_', 'true_divide_out',
|
||||
'floor_divide', 'floor_divide_', 'floor_divide_out',
|
||||
'dequantize',
|
||||
]
|
||||
|
||||
|
||||
@ -320,7 +319,7 @@ def generate_type_hints(fname, decls, namedtuples, is_tensor=False):
|
||||
numargs = len(decl['arguments'])
|
||||
vararg_pos = int(is_tensor)
|
||||
have_vararg_version = (numargs > vararg_pos and
|
||||
decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef', 'TensorList'} and
|
||||
decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef'} and
|
||||
(numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and
|
||||
(not is_tensor or decl['arguments'][0]['name'] == 'self'))
|
||||
|
||||
@ -328,14 +327,11 @@ def generate_type_hints(fname, decls, namedtuples, is_tensor=False):
|
||||
|
||||
if have_vararg_version:
|
||||
# Two things come into play here: PyTorch has the "magic" that if the first and only positional argument
|
||||
# is an IntArrayRef or TensorList, it will be used as a vararg variant.
|
||||
# is an IntArrayRef, it will be used as a vararg variant.
|
||||
# The following outputs the vararg variant, the "pass a list variant" is output above.
|
||||
# The other thing is that in Python, the varargs are annotated with the element type, not the list type.
|
||||
typelist = decl['arguments'][vararg_pos]['dynamic_type']
|
||||
if typelist == 'IntArrayRef':
|
||||
vararg_type = '_int'
|
||||
else:
|
||||
vararg_type = 'Tensor'
|
||||
vararg_type = '_int'
|
||||
# replace first argument and eliminate '*' if present
|
||||
python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] +
|
||||
': ' + vararg_type] + python_args[vararg_pos + 2:])
|
||||
@ -419,6 +415,9 @@ def gen_nn_functional(out):
|
||||
}
|
||||
write(out, 'torch/nn/functional.pyi', stubs, env)
|
||||
|
||||
stubs = CodeTemplate.from_file(os.path.join('torch', '_C', '_nn.pyi.in'))
|
||||
write(out, 'torch/_C/_nn.pyi', stubs, env)
|
||||
|
||||
def gen_nn_pyi(out):
|
||||
gen_nn_functional(out)
|
||||
gen_nn_modules(out)
|
||||
@ -485,7 +484,9 @@ def gen_pyi(declarations_path, out):
|
||||
'def full(size: _size, fill_value: Number, *,'
|
||||
' names: List[Union[str, None]], {}) -> Tensor: ...'
|
||||
.format(FACTORY_PARAMS)],
|
||||
'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...']
|
||||
'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
|
||||
'nonzero': ['def nonzero(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...',
|
||||
'def nonzero(input: Tensor, *, as_tuple: bool=...) -> Tensor: ...'],
|
||||
})
|
||||
for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
|
||||
unsorted_function_hints[binop].append(
|
||||
@ -622,11 +623,14 @@ def gen_pyi(declarations_path, out):
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
# TODO: These are deprecated, maybe we shouldn't type hint them
|
||||
legacy_class_hints = []
|
||||
for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
|
||||
'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage'):
|
||||
legacy_class_hints.append('class {}(Storage): ...'.format(c))
|
||||
legacy_storage_base_hints = []
|
||||
for c in ('Double', 'Float', 'Long', 'Int',
|
||||
'Short', 'Char', 'Byte', 'Bool',
|
||||
'Half', 'BFloat16', 'ComplexDouble',
|
||||
'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32'):
|
||||
legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
|
||||
|
||||
legacy_class_hints = []
|
||||
for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
|
||||
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
|
||||
legacy_class_hints.append('class {}(Tensor): ...'.format(c))
|
||||
@ -650,11 +654,15 @@ def gen_pyi(declarations_path, out):
|
||||
'function_hints': function_hints,
|
||||
'tensor_method_hints': tensor_method_hints,
|
||||
'legacy_class_hints': legacy_class_hints,
|
||||
'legacy_storage_base_hints': legacy_storage_base_hints,
|
||||
'dtype_class_hints': dtype_class_hints,
|
||||
}
|
||||
TORCH_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '__init__.pyi.in'))
|
||||
TORCH_C_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '_C', '__init__.pyi.in'))
|
||||
TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS = \
|
||||
CodeTemplate.from_file(os.path.join('torch', '_C', '_VariableFunctions.pyi.in'))
|
||||
|
||||
write(out, 'torch/__init__.pyi', TORCH_TYPE_STUBS, env)
|
||||
write(out, 'torch/_C/__init__.pyi', TORCH_C_TYPE_STUBS, env)
|
||||
write(out, 'torch/_C/_VariableFunctions.pyi', TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS, env)
|
||||
gen_nn_pyi(out)
|
||||
|
||||
|
||||
|
@ -268,7 +268,8 @@ set(ModulesStubOut
|
||||
${TORCH_SRC_DIR}/nn/modules/upsampling.pyi
|
||||
)
|
||||
add_custom_target(torch_python_stubs DEPENDS
|
||||
"${TORCH_SRC_DIR}/__init__.pyi"
|
||||
"${TORCH_SRC_DIR}/_C/__init__.pyi"
|
||||
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
|
||||
"${TORCH_SRC_DIR}/nn/functional.pyi"
|
||||
${ModuleStubOut}
|
||||
)
|
||||
@ -276,7 +277,8 @@ add_custom_target(torch_python_stubs DEPENDS
|
||||
add_dependencies(torch_python_stubs ATEN_CPU_FILES_GEN_TARGET)
|
||||
add_custom_command(
|
||||
OUTPUT
|
||||
"${TORCH_SRC_DIR}/__init__.pyi"
|
||||
"${TORCH_SRC_DIR}/_C/__init__.pyi"
|
||||
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
|
||||
"${TORCH_SRC_DIR}/nn/functional.pyi"
|
||||
${ModuleStubOut}
|
||||
COMMAND
|
||||
@ -284,7 +286,8 @@ add_custom_command(
|
||||
--declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
|
||||
DEPENDS
|
||||
"${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
|
||||
"${TORCH_SRC_DIR}/__init__.pyi.in"
|
||||
"${TORCH_SRC_DIR}/_C/__init__.pyi.in"
|
||||
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi.in"
|
||||
"${TORCH_SRC_DIR}/nn/functional.pyi.in"
|
||||
${ModuleStubIn}
|
||||
"${TOOLS_PATH}/pyi/gen_pyi.py"
|
||||
|
14
torch/_C/_VariableFunctions.pyi.in
Normal file
14
torch/_C/_VariableFunctions.pyi.in
Normal file
@ -0,0 +1,14 @@
|
||||
# ${generated_comment}
|
||||
|
||||
from torch import Tensor, Generator, strided, memory_format, contiguous_format
|
||||
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar
|
||||
from torch._six import inf
|
||||
|
||||
from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout
|
||||
|
||||
import builtins
|
||||
|
||||
# REDUNDANT!
|
||||
${namedtuple_defs}
|
||||
|
||||
${function_hints}
|
@ -1,86 +0,0 @@
|
||||
import torch
|
||||
from typing import Optional, TypeVar, Callable, Any
|
||||
|
||||
from . import _nn as _nn
|
||||
from . import _onnx as _onnx
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# Defined in torch/csrc/autograd/python_function.cpp
|
||||
class _FunctionBase(object):
|
||||
# TODO
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/autograd/python_legacy_variable.cpp
|
||||
class _LegacyVariableBase(object):
|
||||
def __init__(
|
||||
self,
|
||||
data: Optional['torch.Tensor']=...,
|
||||
requires_grad: Optional[bool]=...,
|
||||
volatile: Optional[bool]=...,
|
||||
_grad_fn: Optional[_FunctionBase]=...
|
||||
) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/init.cpp
|
||||
def _jit_get_operation(op_name: str) -> Callable: ...
|
||||
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...
|
||||
|
||||
# Defined in torch/csrc/Module.cpp
|
||||
def _show_config() -> str: ...
|
||||
def _parallel_info() -> str: ...
|
||||
def _add_docstr(obj: T, doc_obj: str) -> T: ...
|
||||
def _from_dlpack(data: Any) -> 'torch.Tensor': ...
|
||||
def _to_dlpack(data: 'torch.Tensor') -> Any: ...
|
||||
def _set_backcompat_broadcast_warn(arg: bool) -> None: ...
|
||||
def _get_backcompat_broadcast_warn() -> bool: ...
|
||||
def _set_backcompat_keepdim_warn(arg: bool) -> None: ...
|
||||
def _get_backcompat_keepdim_warn() -> bool: ...
|
||||
def _is_xnnpack_enabled() -> bool: ...
|
||||
def _get_mkldnn_enabled() -> bool: ...
|
||||
def _set_mkldnn_enabled(arg: bool) -> None: ...
|
||||
has_openmp: bool
|
||||
has_mkldnn: bool
|
||||
has_mkl: bool
|
||||
|
||||
# Defined in tools/autograd/templates/python_torch_functions.cpp
|
||||
# TODO: This is technically wrong
|
||||
class _VariableFunctions(object):
|
||||
# TODO
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
class FileCheck(object):
|
||||
# TODO
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/Generator.cpp
|
||||
class Generator(object):
|
||||
device: 'torch.device'
|
||||
def get_state(self) -> 'torch.Tensor': ...
|
||||
def set_state(self, _new_state: 'torch.Tensor') -> Generator: ...
|
||||
def manual_seed(self, seed: int) -> Generator: ...
|
||||
def seed(self) -> int: ...
|
||||
def initial_seed(self) -> int: ...
|
||||
|
||||
# Defined in torch/csrc/utils/init.cpp
|
||||
class BenchmarkConfig(object):
|
||||
num_calling_threads: int
|
||||
num_worker_threads: int
|
||||
num_warmup_iters: int
|
||||
num_iters: int
|
||||
profiler_output_path: str
|
||||
|
||||
class BenchmarkExecutionStats(object):
|
||||
latency_avg_ms: float
|
||||
num_iters: int
|
||||
|
||||
class ThroughputBenchmark(object):
|
||||
def __init__(self, module: Any) -> None: ...
|
||||
def add_input(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...
|
||||
|
||||
# Defined in torch/csrc/autograd/python_variable.cpp
|
||||
# This is gonna need to be code'genned.
|
||||
class _TensorBase(object):
|
||||
...
|
157
torch/_C/__init__.pyi.in
Normal file
157
torch/_C/__init__.pyi.in
Normal file
@ -0,0 +1,157 @@
|
||||
# ${generated_comment}
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar, Type
|
||||
from torch._six import inf
|
||||
|
||||
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Number
|
||||
|
||||
import builtins
|
||||
|
||||
from . import _nn as _nn
|
||||
from . import _onnx as _onnx
|
||||
from . import _VariableFunctions as _VariableFunctions
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# Defined in torch/csrc/Device.cpp
|
||||
class device:
|
||||
type: str
|
||||
index: _int
|
||||
|
||||
@overload
|
||||
def __init__(self, device: Union[_int, str]) -> None: ...
|
||||
|
||||
@overload
|
||||
def __init__(self, type: str, index: _int) -> None: ...
|
||||
|
||||
# TODO: __reduce__
|
||||
|
||||
# Defined in torch/csrc/Size.cpp
|
||||
class Size(Tuple[_int, ...]):
|
||||
# TODO: numel, __reduce__
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/Dtype.cpp
|
||||
class dtype:
|
||||
# TODO: is_floating_point, is_complex, is_Signed, __reduce__
|
||||
...
|
||||
|
||||
${dtype_class_hints}
|
||||
|
||||
# Defined in torch/csrc/Layout.cpp
|
||||
class layout:
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/utils/tensor_layouts.cpp
|
||||
strided : layout = ...
|
||||
sparse_coo : layout = ...
|
||||
|
||||
# Defined in torch/csrc/MemoryFormat.cpp
|
||||
class memory_format: ...
|
||||
|
||||
# Defined in torch/csrc/utils/tensor_memoryformats.cpp
|
||||
contiguous_format: memory_format = ...
|
||||
|
||||
# Defined in torch/csrc/QScheme.cpp
|
||||
class qscheme: ...
|
||||
|
||||
# Defined in torch/csrc/utils/tensor_qschemes.cpp
|
||||
per_tensor_affine: qscheme = ...
|
||||
|
||||
# Defined in torch/csrc/generic/Storage.cpp
|
||||
class Storage: ...
|
||||
|
||||
# Defined in torch/csrc/autograd/python_function.cpp
|
||||
class _FunctionBase(object):
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/autograd/python_legacy_variable.cpp
|
||||
class _LegacyVariableBase(object):
|
||||
def __init__(
|
||||
self,
|
||||
data: Optional[Tensor]=...,
|
||||
requires_grad: Optional[_bool]=...,
|
||||
volatile: Optional[_bool]=...,
|
||||
_grad_fn: Optional[_FunctionBase]=...
|
||||
) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/init.cpp
|
||||
def _jit_get_operation(op_name: str) -> Callable: ...
|
||||
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...
|
||||
|
||||
# Defined in torch/csrc/Module.cpp
|
||||
def _init_names(arg: Sequence[Type]) -> None: ...
|
||||
def _show_config() -> str: ...
|
||||
def _parallel_info() -> str: ...
|
||||
def _add_docstr(obj: T, doc_obj: str) -> T: ...
|
||||
def _from_dlpack(data: Any) -> Tensor: ...
|
||||
def _to_dlpack(data: Tensor) -> Any: ...
|
||||
def _set_backcompat_broadcast_warn(arg: _bool) -> None: ...
|
||||
def _get_backcompat_broadcast_warn() -> _bool: ...
|
||||
def _set_backcompat_keepdim_warn(arg: _bool) -> None: ...
|
||||
def _get_backcompat_keepdim_warn() -> _bool: ...
|
||||
def _is_xnnpack_enabled() -> _bool: ...
|
||||
def _get_mkldnn_enabled() -> _bool: ...
|
||||
def _set_mkldnn_enabled(arg: _bool) -> None: ...
|
||||
def _set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
|
||||
def _set_default_dtype(d: _dtype) -> None: ...
|
||||
def _initExtension(shm_manager_path: str) -> None: ...
|
||||
has_openmp: _bool
|
||||
has_mkldnn: _bool
|
||||
has_mkl: _bool
|
||||
_GLIBCXX_USE_CXX11_ABI: _bool
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
class FileCheck(object):
|
||||
# TODO
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/Generator.cpp
|
||||
class Generator(object):
|
||||
device: _device
|
||||
def __init__(self, device: Union[_device, str, None] = None) -> None: ...
|
||||
def get_state(self) -> Tensor: ...
|
||||
def set_state(self, _new_state: Tensor) -> Generator: ...
|
||||
def manual_seed(self, seed: _int) -> Generator: ...
|
||||
def seed(self) -> _int: ...
|
||||
def initial_seed(self) -> _int: ...
|
||||
|
||||
# Defined in torch/csrc/utils/init.cpp
|
||||
class BenchmarkConfig(object):
|
||||
num_calling_threads: _int
|
||||
num_worker_threads: _int
|
||||
num_warmup_iters: _int
|
||||
num_iters: _int
|
||||
profiler_output_path: str
|
||||
|
||||
class BenchmarkExecutionStats(object):
|
||||
latency_avg_ms: _float
|
||||
num_iters: _int
|
||||
|
||||
class ThroughputBenchmark(object):
|
||||
def __init__(self, module: Any) -> None: ...
|
||||
def add_input(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...
|
||||
|
||||
# IDK if these are actually exposed here, hope they are
|
||||
${namedtuple_defs}
|
||||
|
||||
# Defined in torch/csrc/generic/Storage.cpp
|
||||
${legacy_storage_base_hints}
|
||||
|
||||
# TODO: where
|
||||
${legacy_class_hints}
|
||||
|
||||
# Defined in torch/csrc/autograd/python_variable.cpp
|
||||
class _TensorBase(object):
|
||||
requires_grad: _bool
|
||||
shape: Size
|
||||
data: Tensor
|
||||
names: List[str]
|
||||
device: _device
|
||||
dtype: _dtype
|
||||
layout: _layout
|
||||
${tensor_method_hints}
|
@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
# Defined in tools/autograd/templates/python_nn_functions.cpp
|
||||
class _nn(object):
|
||||
...
|
||||
|
||||
${dispatched_hints}
|
@ -4,6 +4,8 @@ import types
|
||||
|
||||
|
||||
class VFModule(types.ModuleType):
|
||||
vf: types.ModuleType
|
||||
|
||||
def __init__(self, name):
|
||||
super(VFModule, self).__init__(name)
|
||||
self.vf = torch._C._VariableFunctions
|
||||
|
@ -23,6 +23,8 @@ from ._utils_internal import get_file_path, prepare_multiprocessing_environment,
|
||||
from .version import __version__
|
||||
from ._six import string_classes as _string_classes
|
||||
|
||||
from typing import Set, Type
|
||||
|
||||
__all__ = [
|
||||
'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
|
||||
'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
|
||||
@ -114,10 +116,10 @@ if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \
|
||||
if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'):
|
||||
try:
|
||||
# next try if DLFCN exists
|
||||
import DLFCN as _dl_flags
|
||||
import DLFCN as _dl_flags # type: ignore
|
||||
except ImportError:
|
||||
# as a last attempt, use compile-time constants
|
||||
import torch._dl as _dl_flags
|
||||
import torch._dl as _dl_flags # type: ignore
|
||||
old_flags = sys.getdlopenflags()
|
||||
sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY)
|
||||
from torch._C import *
|
||||
@ -139,6 +141,11 @@ else:
|
||||
_load_global_deps()
|
||||
from torch._C import *
|
||||
|
||||
# Appease the type checker; ordinarily this binding is inserted by the
|
||||
# torch._C module initialization code in C
|
||||
if False:
|
||||
import torch._C as _C
|
||||
|
||||
__all__ += [name for name in dir(_C)
|
||||
if name[0] != '_' and
|
||||
not name.endswith('Base')]
|
||||
@ -311,7 +318,7 @@ _storage_classes = {
|
||||
}
|
||||
|
||||
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
|
||||
_tensor_classes = set()
|
||||
_tensor_classes: Set[Type] = set()
|
||||
|
||||
|
||||
################################################################################
|
||||
@ -332,6 +339,13 @@ def manager_path():
|
||||
_C._initExtension(manager_path())
|
||||
del manager_path
|
||||
|
||||
# Appease the type checker: it can't deal with direct setting of globals().
|
||||
# Note that we will see "too many" functions when reexporting this way; there
|
||||
# is not a good way to fix this problem. Perhaps, try to redesign VariableFunctions
|
||||
# so that this import is good enough
|
||||
if False:
|
||||
from torch._C._VariableFunctions import *
|
||||
|
||||
for name in dir(_C._VariableFunctions):
|
||||
if name.startswith('__'):
|
||||
continue
|
||||
|
@ -1,180 +0,0 @@
|
||||
# ${generated_comment}
|
||||
|
||||
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence
|
||||
from torch._six import inf
|
||||
|
||||
import builtins
|
||||
|
||||
# These identifiers are reexported from other modules. These modules
|
||||
# are not mypy-clean yet, so in order to use this stub file usefully
|
||||
# from mypy you will need to specify --follow-imports=silent.
|
||||
# Not all is lost: these imports still enable IDEs like PyCharm to offer
|
||||
# autocomplete.
|
||||
#
|
||||
# Note: Why does the syntax here look so strange? Import visibility
|
||||
# rules in stubs are different from normal Python files! You must use
|
||||
# 'from ... import ... as ...' syntax to cause an identifier to be
|
||||
# exposed (or use a wildcard); regular syntax is not exposed.
|
||||
from .random import set_rng_state as set_rng_state, get_rng_state as get_rng_state, \
|
||||
manual_seed as manual_seed, initial_seed as initial_seed, seed as seed
|
||||
from ._tensor_str import set_printoptions as set_printoptions
|
||||
from .functional import *
|
||||
from .serialization import save as save, load as load
|
||||
from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
|
||||
set_grad_enabled as set_grad_enabled
|
||||
from ._ops import ops
|
||||
from ._classes import classes
|
||||
from . import autograd as autograd
|
||||
from . import cuda as cuda
|
||||
from . import optim as optim
|
||||
from . import nn as nn
|
||||
from . import multiprocessing as multiprocessing
|
||||
from . import sparse as sparse
|
||||
from . import onnx as onnx
|
||||
from . import jit as jit
|
||||
from . import hub as hub
|
||||
from . import random as random
|
||||
from . import distributions as distributions
|
||||
from . import testing as testing
|
||||
from . import quantization as quantization
|
||||
from . import __config__ as __config__
|
||||
from . import __future__ as __future__
|
||||
|
||||
class dtype: ...
|
||||
|
||||
class layout: ...
|
||||
|
||||
strided : layout = ...
|
||||
|
||||
sparse_coo : layout = ...
|
||||
|
||||
class memory_format: ...
|
||||
|
||||
contiguous_format: memory_format = ...
|
||||
|
||||
class qscheme: ...
|
||||
|
||||
per_tensor_affine: qscheme = ...
|
||||
|
||||
# See https://github.com/python/mypy/issues/4146 for why these workarounds
|
||||
# is necessary
|
||||
_int = builtins.int
|
||||
_float = builtins.float
|
||||
_bool = builtins.bool
|
||||
|
||||
class device:
|
||||
type: str
|
||||
index: _int
|
||||
|
||||
@overload
|
||||
def __init__(self, device: Union[_int, str]) -> None: ...
|
||||
|
||||
@overload
|
||||
def __init__(self, type: str, index: _int) -> None: ...
|
||||
|
||||
class Size(Tuple[_int, ...]): ...
|
||||
|
||||
class Storage: ...
|
||||
|
||||
# See https://github.com/python/mypy/issues/4146 for why these workarounds
|
||||
# is necessary
|
||||
_dtype = dtype
|
||||
_device = device
|
||||
_qscheme = qscheme
|
||||
_size = Union[Size, List[_int], Tuple[_int, ...]]
|
||||
_layout = layout
|
||||
|
||||
# Meta-type for "numeric" things; matches our docs
|
||||
Number = Union[builtins.int, builtins.float, builtins.bool]
|
||||
|
||||
${namedtuple_defs}
|
||||
|
||||
class Generator:
|
||||
device: _device = ...
|
||||
|
||||
@overload
|
||||
def __init__(self, device: Optional[_device]=None) -> None: ...
|
||||
|
||||
@overload
|
||||
def __init__(self, device: Union[_int, str]) -> None: ...
|
||||
|
||||
# TODO: One downside of doing it this way, is direct use of
|
||||
# torch.tensor.Tensor doesn't get type annotations. Nobody
|
||||
# should really do that, so maybe this is not so bad.
|
||||
class Tensor:
|
||||
requires_grad: _bool = ...
|
||||
grad: Optional[Tensor] = ...
|
||||
data: Tensor = ...
|
||||
names: List[str] = ...
|
||||
@property
|
||||
def dtype(self) -> _dtype: ...
|
||||
@property
|
||||
def shape(self) -> Size: ...
|
||||
@property
|
||||
def device(self) -> _device: ...
|
||||
@property
|
||||
def T(self) -> Tensor: ...
|
||||
@property
|
||||
def grad_fn(self) -> Optional[Any]: ...
|
||||
@property
|
||||
def ndim(self) -> _int: ...
|
||||
@property
|
||||
def layout(self) -> _layout: ...
|
||||
|
||||
${tensor_method_hints}
|
||||
|
||||
# Manually defined methods from torch/tensor.py
|
||||
def __len__(self) -> _int: ...
|
||||
def __iter__(self) -> Iterator[Tensor]: ...
|
||||
def __contains__(self, item: Union[Tensor, Number]) -> _bool: ...
|
||||
def register_hook(self, hook: Callable) -> Any: ...
|
||||
def retain_grad(self) -> None: ...
|
||||
def is_shared(self) -> _bool: ...
|
||||
def share_memory_(self) -> None: ...
|
||||
# NamedTensor requires several manually created annotations
|
||||
def align_to(self, *names: Union[str, ellipsis]) -> Tensor: ...
|
||||
def refine_names(self, *names: Union[str, ellipsis, None]) -> Tensor: ...
|
||||
def rename(self, *names: Union[str, None], **rename_map: str) -> Tensor: ...
|
||||
def unflatten(self, dim: Union[str, _int], namedshape: Sequence[Tuple[str, _int]]) -> Tensor: ...
|
||||
# TODO: fill in the types for these, or otherwise figure out some
|
||||
# way to not have to write these out again...
|
||||
def nonzero(self, *, as_tuple=True): ...
|
||||
def norm(self, p="fro", dim=None, keepdim=False): ...
|
||||
def stft(self, n_fft, hop_length=None, win_length=None, window=None,
|
||||
center=True, pad_mode='reflect', normalized=False, onesided=True): ...
|
||||
def istft(self, n_fft, hop_length=None, win_length=None, window=None,
|
||||
center=True, normalized=False, onesided=True, length=None): ...
|
||||
def split(self, split_size, dim=0): ...
|
||||
def unique(self, sorted=True, return_inverse=False, dim=None): ...
|
||||
def unique_consecutive(self, sorted=True, return_inverse=False, return_counts=False, dim=None): ...
|
||||
def lu(self, pivot=True, get_infos=False): ...
|
||||
|
||||
${function_hints}
|
||||
|
||||
${legacy_class_hints}
|
||||
|
||||
${dtype_class_hints}
|
||||
|
||||
# Pure Python functions defined in torch/__init__.py
|
||||
|
||||
def typename(obj) -> str: ...
|
||||
def is_tensor(obj) -> _bool: ...
|
||||
def is_storage(obj) -> _bool: ...
|
||||
def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
|
||||
def set_default_dtype(d : _dtype) -> None: ...
|
||||
def manager_path() -> str: ...
|
||||
def compiled_with_cxx11_abi() -> _bool: ...
|
||||
|
||||
# The return value of this function depends on the value of `as_tuple`,
|
||||
# (similar to `unique`, `lu`, etc.); as such, it is not
|
||||
# possible to type correctly
|
||||
def nonzero(input: Tensor, *, out: Optional[Tensor]=None, as_tuple: Optional[_bool]=None): ...
|
||||
|
||||
# we can't auto generate hints for torch.dequantize because it will generate
|
||||
# `dequantize(*tensors) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...`
|
||||
# which overlaps with
|
||||
# `dequantize(self: Tensor) -> Tensor: ...
|
||||
@overload
|
||||
def dequantize(self: Tensor) -> Tensor: ...
|
||||
@overload
|
||||
def dequantize(tensors: Union[Tuple[Tensor, ...], List[Tensor]]) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...
|
@ -1,4 +1,5 @@
|
||||
from .. import Tensor, _size
|
||||
from torch import Tensor
|
||||
from torch.types import _size
|
||||
from typing import Any, Optional, Tuple, Dict, List, Callable
|
||||
from .common_types import _ratio_any_t
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from .module import Module
|
||||
from typing import Optional
|
||||
from ... import Tensor, _size
|
||||
from torch import Tensor
|
||||
from torch.types import _size
|
||||
from ..common_types import _size_any_t, _maybe_indices_t, _size_1_t, _size_2_t, _size_3_t, _ratio_3_t, _ratio_2_t
|
||||
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from collections import namedtuple
|
||||
from typing import Any, Optional, overload, Union, TypeVar, Tuple, Sequence
|
||||
from ... import Tensor, _dtype, _device
|
||||
from torch import Tensor
|
||||
from torch.types import _dtype, _device
|
||||
|
||||
PackedSequence_ = namedtuple('PackedSequence', ['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])
|
||||
|
||||
|
@ -1,7 +1,26 @@
|
||||
import torch
|
||||
from typing import Union, Sequence
|
||||
from typing import Union, Sequence, List, Tuple
|
||||
|
||||
import builtins
|
||||
|
||||
# Convenience aliases for common composite types that we need
|
||||
# to talk about in PyTorch
|
||||
|
||||
_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]
|
||||
|
||||
# In some cases, these basic types are shadowed by corresponding
|
||||
# top-level values. The underscore variants let us refer to these
|
||||
# types. See https://github.com/python/mypy/issues/4146 for why these
|
||||
# workarounds is necessary
|
||||
_int = builtins.int
|
||||
_float = builtins.float
|
||||
_bool = builtins.bool
|
||||
|
||||
_dtype = torch.dtype
|
||||
_device = torch.device
|
||||
_qscheme = torch.qscheme
|
||||
_size = Union[torch.Size, List[_int], Tuple[_int, ...]]
|
||||
_layout = torch.layout
|
||||
|
||||
# Meta-type for "numeric" things; matches our docs
|
||||
Number = Union[builtins.int, builtins.float, builtins.bool]
|
||||
|
@ -1,6 +1,8 @@
|
||||
from typing import Any, Callable, TypeVar, Generic, overload, Sequence, List, Optional
|
||||
from . import Dataset, Sampler
|
||||
|
||||
from torch.utils.data._utils.worker import get_worker_info as get_worker_info
|
||||
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
T = TypeVar('T')
|
||||
_worker_init_fn_t = Callable[[int], None]
|
||||
|
@ -6,12 +6,12 @@ T = TypeVar('T')
|
||||
class Dataset(Generic[T_co]):
|
||||
def __getitem__(self, index: int) -> T_co: ...
|
||||
def __len__(self) -> int: ...
|
||||
def __add__(self, other: T_co) -> 'ConcatDataset[T_co]': ...
|
||||
# error: Cannot use a covariant type variable as a parameter
|
||||
def __add__(self, other: T_co) -> 'ConcatDataset[T_co]': ... # type: ignore
|
||||
|
||||
class IterableDataset(Dataset[T_co]):
|
||||
def __iter__(self) -> Iterable[T_co]: ...
|
||||
|
||||
|
||||
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
|
||||
tensors: List[Tensor]
|
||||
|
||||
@ -23,6 +23,9 @@ class ConcatDataset(Dataset[T_co]):
|
||||
|
||||
def __init__(self, datasets: Iterable[Dataset]) -> None: ...
|
||||
|
||||
class ChainDataset(Dataset[T_co]):
|
||||
def __init__(self, datasets: Iterable[Dataset]) -> None: ...
|
||||
|
||||
class Subset(Dataset[T_co]):
|
||||
dataset: Dataset[T_co]
|
||||
indices: Sequence[int]
|
||||
|
@ -4,6 +4,6 @@ from . import Sampler, Dataset
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
class DistributedSampler(Sampler[T_co]):
|
||||
def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=..., shuffle: bool=...): ...
|
||||
def __iter__(self) -> Iterator[int]: ...
|
||||
def __iter__(self) -> Iterator[T_co]: ...
|
||||
def __len__(self) -> int: ...
|
||||
def set_epoch(self, epoch: int) -> None: ...
|
||||
|
Reference in New Issue
Block a user