mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Type annotations for util.data. (#18963)
Summary: I haven't had a chance to rigorously try these out yet so don't merge yet. Closes #18725. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18963 Differential Revision: D14832897 Pulled By: ezyang fbshipit-source-id: 4780e7a34126bc66ddbfd9d808dfc9e0edd77e68
This commit is contained in:
committed by
Facebook Github Bot
parent
a2ac260524
commit
0565141728
1
setup.py
1
setup.py
@ -735,6 +735,7 @@ if __name__ == '__main__':
|
||||
'cuda/*.pyi',
|
||||
'optim/*.pyi',
|
||||
'autograd/*.pyi',
|
||||
'utils/data/*.pyi',
|
||||
'lib/*.so*',
|
||||
'lib/*.dylib*',
|
||||
'lib/*.dll',
|
||||
|
6
torch/utils/data/__init__.pyi
Normal file
6
torch/utils/data/__init__.pyi
Normal file
@ -0,0 +1,6 @@
|
||||
from .sampler import Sampler as Sampler, SequentialSampler as SequentialSampler, RandomSampler as RandomSampler, \
|
||||
SubsetRandomSampler as SubsetRandomSampler, WeightedRandomSampler as WeightedRandomSampler, BatchSampler as BatchSampler
|
||||
from .distributed import DistributedSampler as DistributedSampler
|
||||
from .dataset import Dataset as Dataset, TensorDataset as TensorDataset, ConcatDataset as ConcatDataset, \
|
||||
Subset as Subset, random_split as random_split
|
||||
from .dataloader import DataLoader as DataLoader
|
39
torch/utils/data/dataloader.pyi
Normal file
39
torch/utils/data/dataloader.pyi
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import Any, Callable, TypeVar, Generic, overload, Sequence, List
|
||||
from . import Dataset, Sampler
|
||||
|
||||
# Type variables shared by the stubs below: T_co is the (covariant) element
# type produced by the wrapped Dataset; T is a plain helper TypeVar.
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')

# Signature of the optional per-worker initialisation hook: it receives the
# integer worker id and returns nothing.
_worker_init_fn_t = Callable[[int], None]

# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[List[T]], Any]
|
||||
|
||||
class DataLoader(Generic[T_co]):
    """Typed stub for ``torch.utils.data.DataLoader``.

    Declares the public attributes and the two construction forms visible in
    the overloads: one taking ``batch_size``/``shuffle``/``sampler``/``drop_last``,
    and one taking a ``batch_sampler`` (which yields whole index batches, so
    those per-batch arguments are not accepted alongside it).
    """
    dataset: Dataset[T_co]
    batch_size: int
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float

    @overload
    def __init__(self, dataset: Dataset[T_co], batch_size: int=..., shuffle: bool=..., sampler: Sampler[int]=...,
                 num_workers: int=..., collate_fn: _collate_fn_t=..., pin_memory: bool=...,
                 drop_last: bool=..., timeout: float=..., worker_init_fn: _worker_init_fn_t=...) -> None: ...
    @overload
    def __init__(self, dataset: Dataset[T_co], batch_sampler: Sampler[Sequence[int]]=..., num_workers: int=...,
                 collate_fn: _collate_fn_t=..., pin_memory: bool=..., timeout: float=...,
                 worker_init_fn: _worker_init_fn_t=...) -> None: ...

    def __len__(self) -> int: ...
    # We quote '_DataLoaderIter' since it isn't defined yet and the definition can't be moved up since
    # '_DataLoaderIter' references 'DataLoader'. Pending updates of PEP 484 will fix this.
    def __iter__(self) -> '_DataLoaderIter':...
|
||||
|
||||
class _DataLoaderIter:
    """Stub for the iterator returned by ``DataLoader.__iter__``.

    ``__next__`` is typed ``Any`` because the element type (the return type of
    ``collate_fn``) is not tracked; see the note on ``_collate_fn_t`` above.
    """
    def __init__(self, loader: DataLoader) -> None:...
    def __len__(self) -> int: ...
    def __iter__(self) -> _DataLoaderIter: ...
    def __next__(self) -> Any: ...
|
28
torch/utils/data/dataset.pyi
Normal file
28
torch/utils/data/dataset.pyi
Normal file
@ -0,0 +1,28 @@
|
||||
from typing import TypeVar, Generic, Iterable, Sequence, List, Tuple
|
||||
from ... import Tensor
|
||||
|
||||
# T_co: covariant item type of a Dataset; T: invariant helper used by random_split.
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
|
||||
class Dataset(Generic[T_co]):
    """Typed stub for an abstract map-style dataset of ``T_co`` items,
    indexable by integer and with a known length."""
    def __getitem__(self, index: int) -> T_co: ...
    def __len__(self) -> int: ...
    # `+` concatenates two datasets into a ConcatDataset, so the right-hand
    # operand is another Dataset. (The original stub typed it as a bare item
    # `T_co`, which would reject `dataset_a + dataset_b`.)
    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]': ...
|
||||
|
||||
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    """Stub for a dataset wrapping tensors; each item is a tuple with one
    element per wrapped tensor (see the ``Dataset[Tuple[Tensor, ...]]`` base)."""
    # ``*tensors`` is captured as a tuple at the call site, so the stored
    # attribute is a tuple, not a list (the original stub said List[Tensor]).
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None: ...
|
||||
|
||||
class ConcatDataset(Dataset[T_co]):
    """Stub for a dataset formed by concatenating several datasets."""
    datasets: List[Dataset[T_co]]
    # NOTE(review): presumably the cumulative lengths of `datasets`, used for
    # index lookup — confirm against the implementation.
    cumulative_sizes: List[int]

    # Bind the element type: the `datasets` attribute is List[Dataset[T_co]],
    # so the constructor argument is annotated consistently (the original stub
    # took a bare, unparameterized `Dataset`, losing T_co inference).
    def __init__(self, datasets: Iterable[Dataset[T_co]]) -> None: ...
|
||||
|
||||
class Subset(Dataset[T_co]):
    """Stub for a view of ``dataset`` restricted to the given ``indices``."""
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None: ...
|
||||
|
||||
# Randomly split `dataset` into subsets of the given `lengths`, returned as
# Subset views (presumably non-overlapping — see the implementation).
def random_split(dataset: Dataset[T], lengths: Sequence[int]) -> List[Subset[T]]: ...
|
9
torch/utils/data/distributed.pyi
Normal file
9
torch/utils/data/distributed.pyi
Normal file
@ -0,0 +1,9 @@
|
||||
from typing import TypeVar, Optional, Iterable, Iterator
|
||||
from . import Sampler, Dataset
|
||||
|
||||
# Covariant sample type yielded by DistributedSampler.
T_co = TypeVar('T_co', covariant=True)
|
||||
class DistributedSampler(Sampler[T_co]):
    """Stub for a sampler that restricts iteration to a per-replica shard of
    the dataset (selected by ``rank`` out of ``num_replicas``)."""
    # `-> None` added: __init__ never returns a value.
    def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=...) -> None: ...
    # Fixed: `__iter__` must return an Iterator, not merely an Iterable —
    # `iter(sampler)` hands the result straight to `next()`.
    def __iter__(self) -> Iterator[int]: ...
    def __len__(self) -> int: ...
    # Seeds the shuffling for the given epoch so replicas stay in sync.
    def set_epoch(self, epoch: int) -> None: ...
|
24
torch/utils/data/sampler.pyi
Normal file
24
torch/utils/data/sampler.pyi
Normal file
@ -0,0 +1,24 @@
|
||||
from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized
|
||||
|
||||
# Covariant element type yielded by a Sampler.
T_co = TypeVar('T_co', covariant=True)
|
||||
class Sampler(Generic[T_co]):
    """Typed stub for the base sampler: iterates over elements of type
    ``T_co`` drawn from a sized ``data_source``."""
    def __init__(self, data_source: Sized) -> None: ...
    def __iter__(self) -> Iterator[T_co]: ...
    def __len__(self) -> int: ...
|
||||
|
||||
class SequentialSampler(Sampler[int]):
    """Stub: yields integer indices; interface inherited from Sampler unchanged."""
    pass
|
||||
|
||||
class RandomSampler(Sampler[int]):
    """Stub: samples integer indices at random, optionally with replacement."""
    # Number of samples drawn per pass — presumably defaults to
    # len(data_source) when the constructor argument is omitted; confirm.
    num_samples: int

    def __init__(self, data_source: Sized, replacement: bool=..., num_samples: Optional[int]=...) -> None: ...
|
||||
|
||||
class SubsetRandomSampler(Sampler[int]):
    """Stub: samples randomly from the fixed index sequence ``indices``."""
    def __init__(self, indices: Sequence[int]) -> None: ...
|
||||
|
||||
class WeightedRandomSampler(Sampler[int]):
    """Stub: draws ``num_samples`` indices with per-index ``weights``,
    with or without replacement."""
    def __init__(self, weights: Sequence[float], num_samples: int, replacement: bool=...) -> None: ...
|
||||
|
||||
class BatchSampler(Sampler[List[int]]):
    """Stub: wraps an index sampler to yield whole batches (lists of indices);
    ``drop_last`` controls whether a final short batch is emitted."""
    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None: ...
|
Reference in New Issue
Block a user