Delete torch/__init__.pyi, deferring to direct extension stubs (#38157)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38157

This removes the error-prone process of assembling `torch/__init__.pyi`
(and frequently forgetting to expose things), since now we can simply
rely on the true source file to get things done.  Most of the old
codegen in gen_pyi.py is now rerouted to these files:

- `torch/_C/__init__.pyi` (the dumping ground for all miscellaneous bindings)
- `torch/_C/_nn.pyi` (NN function bindings)
- `torch/_C/_VariableFunctions.pyi` (torch function bindings)

`torch.types` grew a bunch more definitions that were previously
defined in `torch/__init__.pyi`.

Some miscellaneous changes:

- Fixed a bug where we treated a single TensorList argument as implying
  that varargs are accepted; this is actually only supported for IntList.
  With that fixed, we can correctly generate a stub for dequantize (see
  the first sketch after this list).
- Added a missing manual stub for nonzero.
- Switched torch/onnx/operators.py to refer to the _C module directly,
  since apparently mypy doesn't consider underscore-prefixed names to be
  re-exported (see the second sketch after this list).  This may be a
  recurring theme; maybe we need to find a better way to solve it.
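
To make the dequantize item concrete, here is a rough sketch of the stub
shapes involved (illustrative, not the verbatim generated output).  The
buggy TensorList-as-varargs handling would have emitted a varargs
overload that overlaps with the single-Tensor form; keeping the argument
as a plain list matches the overloads the old torch/__init__.pyi.in
spelled out by hand:

    from typing import List, Tuple, Union, overload
    from torch import Tensor

    # Wrong: roughly what the varargs treatment of a lone TensorList
    # argument produced, overlapping the single-Tensor overload below.
    #   def dequantize(*tensors: Tensor) -> Tensor: ...

    # Intended: keep the list-typed argument as a distinct overload.
    @overload
    def dequantize(self: Tensor) -> Tensor: ...
    @overload
    def dequantize(tensors: Union[Tuple[Tensor, ...], List[Tensor]]) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...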
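
And a minimal sketch of the mypy behavior behind the
torch/onnx/operators.py change (the failing access path is hypothetical;
only the direct torch._C form below is what this diff actually uses):

    import torch._C

    # torch/__init__.py re-exports bindings with `from torch._C import *`,
    # and wildcard imports skip underscore-prefixed names, so an access
    # like torch._VariableFunctions (hypothetical) is an undefined
    # attribute as far as mypy is concerned.

    # Going through the extension module directly type-checks, because
    # torch/_C/__init__.pyi declares
    # `from . import _VariableFunctions as _VariableFunctions`.
    vf = torch._C._VariableFunctions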

Because I was really lazy, I dumped namedtuple definitions in both
`torch._C` and `torch._C._VariableFunctions`.  This is definitely wrong.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21497400

Pulled By: ezyang

fbshipit-source-id: 07b126141c82efaca37be27c07255cb2b9b3f064
Author: Edward Yang
Date: 2020-05-11 07:17:59 -07:00
Committed by: Facebook GitHub Bot
Parent: 6f396e18c3
Commit: 6edf340338
19 changed files with 263 additions and 296 deletions

.gitignore

@@ -48,6 +48,9 @@ third_party/build/
tools/shared/_utils_internal.py
torch.egg-info/
torch/__init__.pyi
torch/_C/__init__.pyi
torch/_C/_nn.pyi
torch/_C/_VariableFunctions.pyi
torch/nn/functional.pyi
torch/nn/modules/*.pyi
torch/csrc/autograd/generated/*


@@ -295,6 +295,9 @@ ignore_errors = True
[mypy-torch.nn.intrinsic.qat.modules.conv_fused]
ignore_errors = True
[mypy-torch.onnx.operators]
ignore_errors = True
[mypy-torch.onnx.symbolic_opset8]
ignore_errors = True


@@ -758,7 +758,7 @@ if __name__ == '__main__':
'py.typed',
'bin/*',
'test/*',
'__init__.pyi',
'_C/*.pyi',
'cuda/*.pyi',
'optim/*.pyi',
'autograd/*.pyi',


@@ -103,7 +103,6 @@ blacklist = [
'div_out',
'true_divide', 'true_divide_', 'true_divide_out',
'floor_divide', 'floor_divide_', 'floor_divide_out',
'dequantize',
]
@@ -320,7 +319,7 @@ def generate_type_hints(fname, decls, namedtuples, is_tensor=False):
numargs = len(decl['arguments'])
vararg_pos = int(is_tensor)
have_vararg_version = (numargs > vararg_pos and
decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef', 'TensorList'} and
decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef'} and
(numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and
(not is_tensor or decl['arguments'][0]['name'] == 'self'))
@@ -328,14 +327,11 @@ def generate_type_hints(fname, decls, namedtuples, is_tensor=False):
if have_vararg_version:
# Two things come into play here: PyTorch has the "magic" that if the first and only positional argument
# is an IntArrayRef or TensorList, it will be used as a vararg variant.
# is an IntArrayRef, it will be used as a vararg variant.
# The following outputs the vararg variant, the "pass a list variant" is output above.
# The other thing is that in Python, the varargs are annotated with the element type, not the list type.
typelist = decl['arguments'][vararg_pos]['dynamic_type']
if typelist == 'IntArrayRef':
vararg_type = '_int'
else:
vararg_type = 'Tensor'
vararg_type = '_int'
# replace first argument and eliminate '*' if present
python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] +
': ' + vararg_type] + python_args[vararg_pos + 2:])
@@ -419,6 +415,9 @@ def gen_nn_functional(out):
}
write(out, 'torch/nn/functional.pyi', stubs, env)
stubs = CodeTemplate.from_file(os.path.join('torch', '_C', '_nn.pyi.in'))
write(out, 'torch/_C/_nn.pyi', stubs, env)
def gen_nn_pyi(out):
gen_nn_functional(out)
gen_nn_modules(out)
@@ -485,7 +484,9 @@ def gen_pyi(declarations_path, out):
'def full(size: _size, fill_value: Number, *,'
' names: List[Union[str, None]], {}) -> Tensor: ...'
.format(FACTORY_PARAMS)],
'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...']
'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
'nonzero': ['def nonzero(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...',
'def nonzero(input: Tensor, *, as_tuple: bool=...) -> Tensor: ...'],
})
for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
unsorted_function_hints[binop].append(
@@ -622,11 +623,14 @@ def gen_pyi(declarations_path, out):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# TODO: These are deprecated, maybe we shouldn't type hint them
legacy_class_hints = []
for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage'):
legacy_class_hints.append('class {}(Storage): ...'.format(c))
legacy_storage_base_hints = []
for c in ('Double', 'Float', 'Long', 'Int',
'Short', 'Char', 'Byte', 'Bool',
'Half', 'BFloat16', 'ComplexDouble',
'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32'):
legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
legacy_class_hints = []
for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
legacy_class_hints.append('class {}(Tensor): ...'.format(c))
@@ -650,11 +654,15 @@ def gen_pyi(declarations_path, out):
'function_hints': function_hints,
'tensor_method_hints': tensor_method_hints,
'legacy_class_hints': legacy_class_hints,
'legacy_storage_base_hints': legacy_storage_base_hints,
'dtype_class_hints': dtype_class_hints,
}
TORCH_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '__init__.pyi.in'))
TORCH_C_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '_C', '__init__.pyi.in'))
TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS = \
CodeTemplate.from_file(os.path.join('torch', '_C', '_VariableFunctions.pyi.in'))
write(out, 'torch/__init__.pyi', TORCH_TYPE_STUBS, env)
write(out, 'torch/_C/__init__.pyi', TORCH_C_TYPE_STUBS, env)
write(out, 'torch/_C/_VariableFunctions.pyi', TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS, env)
gen_nn_pyi(out)


@@ -268,7 +268,8 @@ set(ModulesStubOut
${TORCH_SRC_DIR}/nn/modules/upsampling.pyi
)
add_custom_target(torch_python_stubs DEPENDS
"${TORCH_SRC_DIR}/__init__.pyi"
"${TORCH_SRC_DIR}/_C/__init__.pyi"
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
"${TORCH_SRC_DIR}/nn/functional.pyi"
${ModuleStubOut}
)
@@ -276,7 +277,8 @@ add_custom_target(torch_python_stubs DEPENDS
add_dependencies(torch_python_stubs ATEN_CPU_FILES_GEN_TARGET)
add_custom_command(
OUTPUT
"${TORCH_SRC_DIR}/__init__.pyi"
"${TORCH_SRC_DIR}/_C/__init__.pyi"
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
"${TORCH_SRC_DIR}/nn/functional.pyi"
${ModuleStubOut}
COMMAND
@@ -284,7 +286,8 @@ add_custom_command(
--declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
DEPENDS
"${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
"${TORCH_SRC_DIR}/__init__.pyi.in"
"${TORCH_SRC_DIR}/_C/__init__.pyi.in"
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi.in"
"${TORCH_SRC_DIR}/nn/functional.pyi.in"
${ModuleStubIn}
"${TOOLS_PATH}/pyi/gen_pyi.py"


@@ -0,0 +1,14 @@
# ${generated_comment}
from torch import Tensor, Generator, strided, memory_format, contiguous_format
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar
from torch._six import inf
from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout
import builtins
# REDUNDANT!
${namedtuple_defs}
${function_hints}


@@ -1,86 +0,0 @@
import torch
from typing import Optional, TypeVar, Callable, Any
from . import _nn as _nn
from . import _onnx as _onnx
T = TypeVar('T')
# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase(object):
# TODO
...
# Defined in torch/csrc/autograd/python_legacy_variable.cpp
class _LegacyVariableBase(object):
def __init__(
self,
data: Optional['torch.Tensor']=...,
requires_grad: Optional[bool]=...,
volatile: Optional[bool]=...,
_grad_fn: Optional[_FunctionBase]=...
) -> None: ...
# Defined in torch/csrc/jit/python/init.cpp
def _jit_get_operation(op_name: str) -> Callable: ...
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...
# Defined in torch/csrc/Module.cpp
def _show_config() -> str: ...
def _parallel_info() -> str: ...
def _add_docstr(obj: T, doc_obj: str) -> T: ...
def _from_dlpack(data: Any) -> 'torch.Tensor': ...
def _to_dlpack(data: 'torch.Tensor') -> Any: ...
def _set_backcompat_broadcast_warn(arg: bool) -> None: ...
def _get_backcompat_broadcast_warn() -> bool: ...
def _set_backcompat_keepdim_warn(arg: bool) -> None: ...
def _get_backcompat_keepdim_warn() -> bool: ...
def _is_xnnpack_enabled() -> bool: ...
def _get_mkldnn_enabled() -> bool: ...
def _set_mkldnn_enabled(arg: bool) -> None: ...
has_openmp: bool
has_mkldnn: bool
has_mkl: bool
# Defined in tools/autograd/templates/python_torch_functions.cpp
# TODO: This is technically wrong
class _VariableFunctions(object):
# TODO
...
# Defined in torch/csrc/jit/python/script_init.cpp
class FileCheck(object):
# TODO
...
# Defined in torch/csrc/Generator.cpp
class Generator(object):
device: 'torch.device'
def get_state(self) -> 'torch.Tensor': ...
def set_state(self, _new_state: 'torch.Tensor') -> Generator: ...
def manual_seed(self, seed: int) -> Generator: ...
def seed(self) -> int: ...
def initial_seed(self) -> int: ...
# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig(object):
num_calling_threads: int
num_worker_threads: int
num_warmup_iters: int
num_iters: int
profiler_output_path: str
class BenchmarkExecutionStats(object):
latency_avg_ms: float
num_iters: int
class ThroughputBenchmark(object):
def __init__(self, module: Any) -> None: ...
def add_input(self, *args: Any, **kwargs: Any) -> None: ...
def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...
# Defined in torch/csrc/autograd/python_variable.cpp
# This is gonna need to be code'genned.
class _TensorBase(object):
...

torch/_C/__init__.pyi.in (new file, 157 lines)

@@ -0,0 +1,157 @@
# ${generated_comment}
import torch
from torch import Tensor
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar, Type
from torch._six import inf
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Number
import builtins
from . import _nn as _nn
from . import _onnx as _onnx
from . import _VariableFunctions as _VariableFunctions
T = TypeVar('T')
# Defined in torch/csrc/Device.cpp
class device:
type: str
index: _int
@overload
def __init__(self, device: Union[_int, str]) -> None: ...
@overload
def __init__(self, type: str, index: _int) -> None: ...
# TODO: __reduce__
# Defined in torch/csrc/Size.cpp
class Size(Tuple[_int, ...]):
# TODO: numel, __reduce__
...
# Defined in torch/csrc/Dtype.cpp
class dtype:
# TODO: is_floating_point, is_complex, is_Signed, __reduce__
...
${dtype_class_hints}
# Defined in torch/csrc/Layout.cpp
class layout:
...
# Defined in torch/csrc/utils/tensor_layouts.cpp
strided : layout = ...
sparse_coo : layout = ...
# Defined in torch/csrc/MemoryFormat.cpp
class memory_format: ...
# Defined in torch/csrc/utils/tensor_memoryformats.cpp
contiguous_format: memory_format = ...
# Defined in torch/csrc/QScheme.cpp
class qscheme: ...
# Defined in torch/csrc/utils/tensor_qschemes.cpp
per_tensor_affine: qscheme = ...
# Defined in torch/csrc/generic/Storage.cpp
class Storage: ...
# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase(object):
...
# Defined in torch/csrc/autograd/python_legacy_variable.cpp
class _LegacyVariableBase(object):
def __init__(
self,
data: Optional[Tensor]=...,
requires_grad: Optional[_bool]=...,
volatile: Optional[_bool]=...,
_grad_fn: Optional[_FunctionBase]=...
) -> None: ...
# Defined in torch/csrc/jit/python/init.cpp
def _jit_get_operation(op_name: str) -> Callable: ...
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...
# Defined in torch/csrc/Module.cpp
def _init_names(arg: Sequence[Type]) -> None: ...
def _show_config() -> str: ...
def _parallel_info() -> str: ...
def _add_docstr(obj: T, doc_obj: str) -> T: ...
def _from_dlpack(data: Any) -> Tensor: ...
def _to_dlpack(data: Tensor) -> Any: ...
def _set_backcompat_broadcast_warn(arg: _bool) -> None: ...
def _get_backcompat_broadcast_warn() -> _bool: ...
def _set_backcompat_keepdim_warn(arg: _bool) -> None: ...
def _get_backcompat_keepdim_warn() -> _bool: ...
def _is_xnnpack_enabled() -> _bool: ...
def _get_mkldnn_enabled() -> _bool: ...
def _set_mkldnn_enabled(arg: _bool) -> None: ...
def _set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
def _set_default_dtype(d: _dtype) -> None: ...
def _initExtension(shm_manager_path: str) -> None: ...
has_openmp: _bool
has_mkldnn: _bool
has_mkl: _bool
_GLIBCXX_USE_CXX11_ABI: _bool
# Defined in torch/csrc/jit/python/script_init.cpp
class FileCheck(object):
# TODO
...
# Defined in torch/csrc/Generator.cpp
class Generator(object):
device: _device
def __init__(self, device: Union[_device, str, None] = None) -> None: ...
def get_state(self) -> Tensor: ...
def set_state(self, _new_state: Tensor) -> Generator: ...
def manual_seed(self, seed: _int) -> Generator: ...
def seed(self) -> _int: ...
def initial_seed(self) -> _int: ...
# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig(object):
num_calling_threads: _int
num_worker_threads: _int
num_warmup_iters: _int
num_iters: _int
profiler_output_path: str
class BenchmarkExecutionStats(object):
latency_avg_ms: _float
num_iters: _int
class ThroughputBenchmark(object):
def __init__(self, module: Any) -> None: ...
def add_input(self, *args: Any, **kwargs: Any) -> None: ...
def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...
# IDK if these are actually exposed here, hope they are
${namedtuple_defs}
# Defined in torch/csrc/generic/Storage.cpp
${legacy_storage_base_hints}
# TODO: where
${legacy_class_hints}
# Defined in torch/csrc/autograd/python_variable.cpp
class _TensorBase(object):
requires_grad: _bool
shape: Size
data: Tensor
names: List[str]
device: _device
dtype: _dtype
layout: _layout
${tensor_method_hints}


@@ -1,3 +1,5 @@
from typing import Callable
# Defined in tools/autograd/templates/python_nn_functions.cpp
class _nn(object):
...
${dispatched_hints}


@@ -4,6 +4,8 @@ import types
class VFModule(types.ModuleType):
vf: types.ModuleType
def __init__(self, name):
super(VFModule, self).__init__(name)
self.vf = torch._C._VariableFunctions


@@ -23,6 +23,8 @@ from ._utils_internal import get_file_path, prepare_multiprocessing_environment,
from .version import __version__
from ._six import string_classes as _string_classes
from typing import Set, Type
__all__ = [
'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
@@ -114,10 +116,10 @@ if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \
if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'):
try:
# next try if DLFCN exists
import DLFCN as _dl_flags
import DLFCN as _dl_flags # type: ignore
except ImportError:
# as a last attempt, use compile-time constants
import torch._dl as _dl_flags
import torch._dl as _dl_flags # type: ignore
old_flags = sys.getdlopenflags()
sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY)
from torch._C import *
@@ -139,6 +141,11 @@ else:
_load_global_deps()
from torch._C import *
# Appease the type checker; ordinarily this binding is inserted by the
# torch._C module initialization code in C
if False:
import torch._C as _C
__all__ += [name for name in dir(_C)
if name[0] != '_' and
not name.endswith('Base')]
@@ -311,7 +318,7 @@ _storage_classes = {
}
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
_tensor_classes = set()
_tensor_classes: Set[Type] = set()
################################################################################
@@ -332,6 +339,13 @@ def manager_path():
_C._initExtension(manager_path())
del manager_path
# Appease the type checker: it can't deal with direct setting of globals().
# Note that we will see "too many" functions when reexporting this way; there
# is not a good way to fix this problem. Perhaps, try to redesign VariableFunctions
# so that this import is good enough
if False:
from torch._C._VariableFunctions import *
for name in dir(_C._VariableFunctions):
if name.startswith('__'):
continue


@@ -1,180 +0,0 @@
# ${generated_comment}
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence
from torch._six import inf
import builtins
# These identifiers are reexported from other modules. These modules
# are not mypy-clean yet, so in order to use this stub file usefully
# from mypy you will need to specify --follow-imports=silent.
# Not all is lost: these imports still enable IDEs like PyCharm to offer
# autocomplete.
#
# Note: Why does the syntax here look so strange? Import visibility
# rules in stubs are different from normal Python files! You must use
# 'from ... import ... as ...' syntax to cause an identifier to be
# exposed (or use a wildcard); regular syntax is not exposed.
from .random import set_rng_state as set_rng_state, get_rng_state as get_rng_state, \
manual_seed as manual_seed, initial_seed as initial_seed, seed as seed
from ._tensor_str import set_printoptions as set_printoptions
from .functional import *
from .serialization import save as save, load as load
from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
set_grad_enabled as set_grad_enabled
from ._ops import ops
from ._classes import classes
from . import autograd as autograd
from . import cuda as cuda
from . import optim as optim
from . import nn as nn
from . import multiprocessing as multiprocessing
from . import sparse as sparse
from . import onnx as onnx
from . import jit as jit
from . import hub as hub
from . import random as random
from . import distributions as distributions
from . import testing as testing
from . import quantization as quantization
from . import __config__ as __config__
from . import __future__ as __future__
class dtype: ...
class layout: ...
strided : layout = ...
sparse_coo : layout = ...
class memory_format: ...
contiguous_format: memory_format = ...
class qscheme: ...
per_tensor_affine: qscheme = ...
# See https://github.com/python/mypy/issues/4146 for why these workarounds
# is necessary
_int = builtins.int
_float = builtins.float
_bool = builtins.bool
class device:
type: str
index: _int
@overload
def __init__(self, device: Union[_int, str]) -> None: ...
@overload
def __init__(self, type: str, index: _int) -> None: ...
class Size(Tuple[_int, ...]): ...
class Storage: ...
# See https://github.com/python/mypy/issues/4146 for why these workarounds
# is necessary
_dtype = dtype
_device = device
_qscheme = qscheme
_size = Union[Size, List[_int], Tuple[_int, ...]]
_layout = layout
# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float, builtins.bool]
${namedtuple_defs}
class Generator:
device: _device = ...
@overload
def __init__(self, device: Optional[_device]=None) -> None: ...
@overload
def __init__(self, device: Union[_int, str]) -> None: ...
# TODO: One downside of doing it this way, is direct use of
# torch.tensor.Tensor doesn't get type annotations. Nobody
# should really do that, so maybe this is not so bad.
class Tensor:
requires_grad: _bool = ...
grad: Optional[Tensor] = ...
data: Tensor = ...
names: List[str] = ...
@property
def dtype(self) -> _dtype: ...
@property
def shape(self) -> Size: ...
@property
def device(self) -> _device: ...
@property
def T(self) -> Tensor: ...
@property
def grad_fn(self) -> Optional[Any]: ...
@property
def ndim(self) -> _int: ...
@property
def layout(self) -> _layout: ...
${tensor_method_hints}
# Manually defined methods from torch/tensor.py
def __len__(self) -> _int: ...
def __iter__(self) -> Iterator[Tensor]: ...
def __contains__(self, item: Union[Tensor, Number]) -> _bool: ...
def register_hook(self, hook: Callable) -> Any: ...
def retain_grad(self) -> None: ...
def is_shared(self) -> _bool: ...
def share_memory_(self) -> None: ...
# NamedTensor requires several manually created annotations
def align_to(self, *names: Union[str, ellipsis]) -> Tensor: ...
def refine_names(self, *names: Union[str, ellipsis, None]) -> Tensor: ...
def rename(self, *names: Union[str, None], **rename_map: str) -> Tensor: ...
def unflatten(self, dim: Union[str, _int], namedshape: Sequence[Tuple[str, _int]]) -> Tensor: ...
# TODO: fill in the types for these, or otherwise figure out some
# way to not have to write these out again...
def nonzero(self, *, as_tuple=True): ...
def norm(self, p="fro", dim=None, keepdim=False): ...
def stft(self, n_fft, hop_length=None, win_length=None, window=None,
center=True, pad_mode='reflect', normalized=False, onesided=True): ...
def istft(self, n_fft, hop_length=None, win_length=None, window=None,
center=True, normalized=False, onesided=True, length=None): ...
def split(self, split_size, dim=0): ...
def unique(self, sorted=True, return_inverse=False, dim=None): ...
def unique_consecutive(self, sorted=True, return_inverse=False, return_counts=False, dim=None): ...
def lu(self, pivot=True, get_infos=False): ...
${function_hints}
${legacy_class_hints}
${dtype_class_hints}
# Pure Python functions defined in torch/__init__.py
def typename(obj) -> str: ...
def is_tensor(obj) -> _bool: ...
def is_storage(obj) -> _bool: ...
def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
def set_default_dtype(d : _dtype) -> None: ...
def manager_path() -> str: ...
def compiled_with_cxx11_abi() -> _bool: ...
# The return value of this function depends on the value of `as_tuple`,
# (similar to `unique`, `lu`, etc.); as such, it is not
# possible to type correctly
def nonzero(input: Tensor, *, out: Optional[Tensor]=None, as_tuple: Optional[_bool]=None): ...
# we can't auto generate hints for torch.dequantize because it will generate
# `dequantize(*tensors) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...`
# which overlaps with
# `dequantize(self: Tensor) -> Tensor: ...
@overload
def dequantize(self: Tensor) -> Tensor: ...
@overload
def dequantize(tensors: Union[Tuple[Tensor, ...], List[Tensor]]) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...


@@ -1,4 +1,5 @@
from .. import Tensor, _size
from torch import Tensor
from torch.types import _size
from typing import Any, Optional, Tuple, Dict, List, Callable
from .common_types import _ratio_any_t


@@ -1,6 +1,7 @@
from .module import Module
from typing import Optional
from ... import Tensor, _size
from torch import Tensor
from torch.types import _size
from ..common_types import _size_any_t, _maybe_indices_t, _size_1_t, _size_2_t, _size_3_t, _ratio_3_t, _ratio_2_t


@@ -1,6 +1,7 @@
from collections import namedtuple
from typing import Any, Optional, overload, Union, TypeVar, Tuple, Sequence
from ... import Tensor, _dtype, _device
from torch import Tensor
from torch.types import _dtype, _device
PackedSequence_ = namedtuple('PackedSequence', ['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])


@@ -1,7 +1,26 @@
import torch
from typing import Union, Sequence
from typing import Union, Sequence, List, Tuple
import builtins
# Convenience aliases for common composite types that we need
# to talk about in PyTorch
_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]
# In some cases, these basic types are shadowed by corresponding
# top-level values. The underscore variants let us refer to these
# types. See https://github.com/python/mypy/issues/4146 for why these
# workarounds is necessary
_int = builtins.int
_float = builtins.float
_bool = builtins.bool
_dtype = torch.dtype
_device = torch.device
_qscheme = torch.qscheme
_size = Union[torch.Size, List[_int], Tuple[_int, ...]]
_layout = torch.layout
# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float, builtins.bool]


@@ -1,6 +1,8 @@
from typing import Any, Callable, TypeVar, Generic, overload, Sequence, List, Optional
from . import Dataset, Sampler
from torch.utils.data._utils.worker import get_worker_info as get_worker_info
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
_worker_init_fn_t = Callable[[int], None]


@@ -6,12 +6,12 @@ T = TypeVar('T')
class Dataset(Generic[T_co]):
def __getitem__(self, index: int) -> T_co: ...
def __len__(self) -> int: ...
def __add__(self, other: T_co) -> 'ConcatDataset[T_co]': ...
# error: Cannot use a covariant type variable as a parameter
def __add__(self, other: T_co) -> 'ConcatDataset[T_co]': ... # type: ignore
class IterableDataset(Dataset[T_co]):
def __iter__(self) -> Iterable[T_co]: ...
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
tensors: List[Tensor]
@@ -23,6 +23,9 @@ class ConcatDataset(Dataset[T_co]):
def __init__(self, datasets: Iterable[Dataset]) -> None: ...
class ChainDataset(Dataset[T_co]):
def __init__(self, datasets: Iterable[Dataset]) -> None: ...
class Subset(Dataset[T_co]):
dataset: Dataset[T_co]
indices: Sequence[int]


@@ -4,6 +4,6 @@ from . import Sampler, Dataset
T_co = TypeVar('T_co', covariant=True)
class DistributedSampler(Sampler[T_co]):
def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=..., shuffle: bool=...): ...
def __iter__(self) -> Iterator[int]: ...
def __iter__(self) -> Iterator[T_co]: ...
def __len__(self) -> int: ...
def set_epoch(self, epoch: int) -> None: ...