mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
More type stubs (#18511)
Summary: Added stubs for: * The `device` module * The `cuda` module * Parts of the `optim` module * Began adding stubs for the `autograd` module. I'll annotate more later but `no_grad` and friends are probably the most used exports from it so it seemed like a good place to start. This would close #16996, although comments on that issue reference other missing stubs so maybe it's worth keeping open as an umbrella issue. The big remaining missing package is `nn`. Also added a `py.typed` file so mypy will pick up on the type stubs. That closes #17639. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18511 Differential Revision: D14715053 Pulled By: ezyang fbshipit-source-id: 9e4882ac997063650e6ce47604b3eaf1232c61c9
This commit is contained in:
committed by
Facebook Github Bot
parent
aa23b8c664
commit
1b25fdbcd0
7
.flake8
7
.flake8
@ -1,8 +1,5 @@
|
|||||||
[flake8]
|
[flake8]
|
||||||
# Notably, P is not included here; this is the Facebook internal
|
select = B,C,E,F,P,T4,W,B9
|
||||||
# code for 'plow' (custom lint rules); they're not supported in OSS
|
|
||||||
# and so we don't enforce them.
|
|
||||||
select = B,C,E,F,T4,W,B9
|
|
||||||
max-line-length = 120
|
max-line-length = 120
|
||||||
# C408 ignored because we like the dict keyword argument syntax
|
# C408 ignored because we like the dict keyword argument syntax
|
||||||
# E501 is not flexible enough, we're using B950 instead
|
# E501 is not flexible enough, we're using B950 instead
|
||||||
@ -12,4 +9,4 @@ ignore =
|
|||||||
B007,B008,
|
B007,B008,
|
||||||
# these ignores are from flake8-comprehensions; please fix!
|
# these ignores are from flake8-comprehensions; please fix!
|
||||||
C400,C401,C402,C403,C404,C405,C407,C411,
|
C400,C401,C402,C403,C404,C405,C407,C411,
|
||||||
exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include,torch/__init__.pyi
|
exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include,*.pyi
|
||||||
|
4
setup.py
4
setup.py
@ -728,9 +728,13 @@ if __name__ == '__main__':
|
|||||||
entry_points=entry_points,
|
entry_points=entry_points,
|
||||||
package_data={
|
package_data={
|
||||||
'torch': [
|
'torch': [
|
||||||
|
'py.typed',
|
||||||
'bin/*',
|
'bin/*',
|
||||||
'test/*',
|
'test/*',
|
||||||
'__init__.pyi',
|
'__init__.pyi',
|
||||||
|
'cuda/*.pyi',
|
||||||
|
'optim/*.pyi',
|
||||||
|
'autograd/*.pyi',
|
||||||
'lib/*.so*',
|
'lib/*.so*',
|
||||||
'lib/*.dylib*',
|
'lib/*.dylib*',
|
||||||
'lib/*.dll',
|
'lib/*.dll',
|
||||||
|
@ -22,6 +22,8 @@ from .functional import *
|
|||||||
from .serialization import save as save, load as load
|
from .serialization import save as save, load as load
|
||||||
from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
|
from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
|
||||||
set_grad_enabled as set_grad_enabled
|
set_grad_enabled as set_grad_enabled
|
||||||
|
from . import cuda as cuda
|
||||||
|
from . import optim as optim
|
||||||
|
|
||||||
class dtype: ...
|
class dtype: ...
|
||||||
|
|
||||||
@ -35,7 +37,14 @@ _int = builtins.int
|
|||||||
_float = builtins.float
|
_float = builtins.float
|
||||||
|
|
||||||
class device:
|
class device:
|
||||||
def __init__(self, device: Union[_int, str, None]=None) -> None: ...
|
type: str
|
||||||
|
index: _int
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __init__(self, device: Union[_int, str]) -> None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __init__(self, type: str, index: _int) -> None: ...
|
||||||
|
|
||||||
class Generator: ...
|
class Generator: ...
|
||||||
|
|
||||||
|
45
torch/autograd/__init__.pyi
Normal file
45
torch/autograd/__init__.pyi
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
from typing import Any, Callable, Union, Tuple, Sequence, Optional
|
||||||
|
from .. import Tensor
|
||||||
|
from .grad_mode import no_grad as no_grad, enable_grad as enable_grad, \
|
||||||
|
set_grad_enabled as set_grad_enabled
|
||||||
|
|
||||||
|
# TODO make Variable and Function more precise
|
||||||
|
class Variable:
|
||||||
|
...
|
||||||
|
|
||||||
|
class Function:
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: ...
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx: Any, *grad_outputs: Any) -> Any: ...
|
||||||
|
|
||||||
|
class NestedIOFunction(Function):
|
||||||
|
# The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
|
||||||
|
# superclass (Function) but are instance methods here, which mypy reports as incomptabile.
|
||||||
|
def backward(self, *gradients: Any) -> Any: ... # type: ignore
|
||||||
|
def forward(self, *args: Any) -> tuple: ... # type: ignore
|
||||||
|
def save_for_backward(self, *args: Any) -> None:...
|
||||||
|
def mark_dirty(self, *args: Any, **kwargs: Any) -> None:...
|
||||||
|
def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None: ...
|
||||||
|
def forward_extended(self, *input: Any) -> None:...
|
||||||
|
def backward_extended(self, *grad_output: Any) -> None: ...
|
||||||
|
|
||||||
|
# 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment.
|
||||||
|
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
|
||||||
|
# the '...' first argument of Callabe can be replaced with VarArg(Tensor).
|
||||||
|
# For now, we permit any input.
|
||||||
|
def gradcheck(func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]], inputs: Union[Tensor, Tuple[Tensor, ...]], eps: float=..., atol: float=..., rtol: float=..., raise_exception: bool=..., check_sparse_nnz: bool=...) -> bool: ...
|
||||||
|
def gradgradcheck(func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]], inputs: Union[Tensor, Tuple[Tensor, ...]], eps: float=..., atol: float=..., rtol: float=..., gen_non_contig_grad_outputs: bool=..., raise_exception: bool=...) -> bool: ...
|
||||||
|
|
||||||
|
class detect_anomaly:
|
||||||
|
def __enter__(self) -> None: ...
|
||||||
|
def __exit__(self, *args: Any) -> bool: ...
|
||||||
|
|
||||||
|
class set_detect_anomaly:
|
||||||
|
def __init__(self, mode: bool) -> None: ...
|
||||||
|
def __enter__(self) -> None:...
|
||||||
|
def __exit__(self, *args: Any) -> bool: ...
|
||||||
|
|
||||||
|
_TensorOrTensors = Union[Tensor, Sequence[Tensor]]
|
||||||
|
def backward(tensors: _TensorOrTensors, grad_tensors: Optional[_TensorOrTensors]=..., retain_graph: Optional[bool]=..., create_graph: bool=...) -> None: ...
|
||||||
|
def grad(outputs: _TensorOrTensors, inputs: _TensorOrTensors, grad_outputs: Optional[_TensorOrTensors]=..., retain_graph: Optional[bool]=..., create_graph: bool=..., only_inputs: bool=..., allow_unused: bool=...) -> Tuple[Tensor, ...]: ...
|
21
torch/autograd/grad_mode.pyi
Normal file
21
torch/autograd/grad_mode.pyi
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from typing import Any, Callable, TypeVar
|
||||||
|
|
||||||
|
# Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
|
||||||
|
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
|
||||||
|
FuncType = Callable[..., Any]
|
||||||
|
T = TypeVar('T', bound=FuncType)
|
||||||
|
|
||||||
|
class no_grad:
|
||||||
|
def __enter__(self) -> None: ...
|
||||||
|
def __exit__(self, *args: Any) -> bool: ...
|
||||||
|
def __call__(self, func: T) -> T: ...
|
||||||
|
|
||||||
|
class enable_grad:
|
||||||
|
def __enter__(self) -> None: ...
|
||||||
|
def __exit__(self, *args: Any) -> bool: ...
|
||||||
|
def __call__(self, func: T) -> T: ...
|
||||||
|
|
||||||
|
class set_grad_enabled:
|
||||||
|
def __init__(self, mode: bool) -> None: ...
|
||||||
|
def __enter__(self) -> None: ...
|
||||||
|
def __exit__(self, *args: Any) -> bool: ...
|
41
torch/cuda/__init__.pyi
Normal file
41
torch/cuda/__init__.pyi
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
import ctypes
|
||||||
|
from .. import device as _device
|
||||||
|
|
||||||
|
def is_available() -> bool: ...
|
||||||
|
def init() -> None: ...
|
||||||
|
|
||||||
|
class cudaStatus:
|
||||||
|
SUCCESS: int
|
||||||
|
ERROR_NOT_READY: int
|
||||||
|
|
||||||
|
class CudaError:
|
||||||
|
def __init__(self, code: int) -> None: ...
|
||||||
|
|
||||||
|
class _CudaDeviceProperties:
|
||||||
|
name: str
|
||||||
|
major: int
|
||||||
|
minor: int
|
||||||
|
multi_processor_count: int
|
||||||
|
total_memory: int
|
||||||
|
is_integrated: int
|
||||||
|
is_multi_gpu_board: int
|
||||||
|
|
||||||
|
_device_t = Union[_device, int]
|
||||||
|
|
||||||
|
def check_error(res: int) -> None: ...
|
||||||
|
def device_count() -> int: ...
|
||||||
|
def empty_cache() -> None: ...
|
||||||
|
def set_device(device: _device_t) -> None: ...
|
||||||
|
def get_device_capability(device: Optional[_device_t]=...) -> Tuple[int, int]: ...
|
||||||
|
def get_device_name(device: Optional[_device_t]=...) -> str: ...
|
||||||
|
def get_device_properties(device: _device_t) -> _CudaDeviceProperties: ...
|
||||||
|
def current_device() -> int: ...
|
||||||
|
def memory_allocated(device: Optional[_device_t]=...) -> int: ...
|
||||||
|
def max_memory_allocated(device: Optional[_device_t]=...) -> int: ...
|
||||||
|
def reset_max_memory_allocated(device: Optional[_device_t]=...) -> None: ...
|
||||||
|
def memory_cached(device: Optional[_device_t]=...) -> int: ...
|
||||||
|
def max_memory_cached(device: Optional[_device_t]=...) -> int: ...
|
||||||
|
def reset_max_memory_cached(device: Optional[_device_t]=...) -> None: ...
|
||||||
|
def cudart() -> ctypes.CDLL: ...
|
||||||
|
def find_cuda_windows_lib() -> Optional[ctypes.CDLL]: ...
|
3
torch/optim/__init__.pyi
Normal file
3
torch/optim/__init__.pyi
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .sgd import SGD as SGD
|
||||||
|
from .adam import Adam as Adam
|
||||||
|
from . import lr_scheduler as lr_scheduler
|
5
torch/optim/adam.pyi
Normal file
5
torch/optim/adam.pyi
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
from .optimizer import _params_t, Optimizer
|
||||||
|
|
||||||
|
class Adam(Optimizer):
|
||||||
|
def __init__(self, params: _params_t, lr: float=..., betas: Tuple[float, float]=..., eps: float=..., weight_decay: float=..., amsgrad: bool = ...) -> None: ...
|
32
torch/optim/lr_scheduler.pyi
Normal file
32
torch/optim/lr_scheduler.pyi
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from typing import Iterable, Any, Optional
|
||||||
|
from .optimizer import Optimizer
|
||||||
|
|
||||||
|
class _LRScheduler:
|
||||||
|
def __init__(self, optimizer: Optimizer, last_epoch: int=...) -> None: ...
|
||||||
|
def state_dict(self) -> dict: ...
|
||||||
|
def load_state_dict(self, state_dict: dict) -> None: ...
|
||||||
|
def get_lr(self) -> float: ...
|
||||||
|
def step(self, epoch: int) -> None: ...
|
||||||
|
|
||||||
|
class LambdaLR(_LRScheduler):
|
||||||
|
def __init__(self, optimizer: Optimizer, lr_lambda: float, last_epoch: int=...) -> None: ...
|
||||||
|
|
||||||
|
class StepLR(_LRScheduler):
|
||||||
|
def __init__(self, optimizer: Optimizer, step_size: int, gamma: float=..., last_epoch: int=...) -> None:...
|
||||||
|
|
||||||
|
class MultiStepLR(_LRScheduler):
|
||||||
|
def __init__(self, optimizer: Optimizer, milestones: Iterable[int], gamma: float=..., last_epoch: int=...) -> None: ...
|
||||||
|
|
||||||
|
class ExponentialLR(_LRScheduler):
|
||||||
|
def __init__(self, optimizer: Optimizer, gamma: float, last_epoch: int=...) -> None: ...
|
||||||
|
|
||||||
|
class CosineAnnealingLr(_LRScheduler):
|
||||||
|
def __init__(self, optimizer: Optimizer, T_max: int, eta_min: float, last_epoch: int=...) -> None: ...
|
||||||
|
|
||||||
|
class ReduceLROnPlateau:
|
||||||
|
in_cooldown: bool
|
||||||
|
|
||||||
|
def __init__(self, optimizer: Optimizer, mode: str=..., factor: float=..., patience: int=..., verbose: bool=..., threshold: float=..., threshold_mode: str=..., cooldown: int=..., min_lr: float=..., eps: float=...) -> None: ...
|
||||||
|
def step(self, metrics: Any, epoch: Optional[int]=...) -> None: ...
|
||||||
|
def state_dict(self) -> dict: ...
|
||||||
|
def load_state_dict(self, state_dict: dict): ...
|
12
torch/optim/optimizer.pyi
Normal file
12
torch/optim/optimizer.pyi
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from typing import Iterable, Union, Callable, Optional
|
||||||
|
from .. import Tensor
|
||||||
|
|
||||||
|
_params_t = Union[Iterable[Tensor], dict]
|
||||||
|
|
||||||
|
class Optimizer:
|
||||||
|
def __init__(self, params: _params_t) -> None: ...
|
||||||
|
def state_dict(self) -> dict: ...
|
||||||
|
def load_state_dict(self, state_dict: dict) -> None: ...
|
||||||
|
def zero_grad(self) -> None: ...
|
||||||
|
def step(self, closure: Optional[Callable[[], float]]=...) -> None: ...
|
||||||
|
def add_param_group(self, param_group: dict) -> None: ...
|
4
torch/optim/sgd.pyi
Normal file
4
torch/optim/sgd.pyi
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from .optimizer import _params_t, Optimizer
|
||||||
|
|
||||||
|
class SGD(Optimizer):
|
||||||
|
def __init__(self, params: _params_t, lr: float, momentum: float=..., dampening: float=..., weight_decay:float=..., nesterov:bool=...) -> None: ...
|
0
torch/py.typed
Normal file
0
torch/py.typed
Normal file
Reference in New Issue
Block a user