# Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
# Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62768
# This is part of TorchArrow DF support preparation, separated into multiple
# PRs to simplify the review process.
# Test Plan: Imported from OSS. Reviewed By: ejguan
# Differential Revision: D30149090. Pulled By: VitalyFedyunin
# fbshipit-source-id: a36b5ff56e2ac6b06060014d4cd41b487754acb8
import random
from torch.utils.data import IterDataPipe, Sampler, SequentialSampler, functional_datapipe
from typing import TypeVar, Type, Iterator, Sized, Optional, Tuple, Dict, List
# Covariant type variable for the element type yielded by the DataPipes below.
T_co = TypeVar('T_co', covariant=True)
class SamplerIterDataPipe(IterDataPipe[T_co]):
    r""" :class:`SamplerIterDataPipe`.

    Iterable DataPipe to generate sample elements.

    Args:
        datapipe: IterDataPipe sampled from; must implement ``__len__``
        sampler: Sampler class to generate sample elements from input DataPipe.
            Default is :class:`SequentialSampler` for IterDataPipe
        sampler_args: Positional arguments forwarded to the ``sampler`` constructor
        sampler_kwargs: Keyword arguments forwarded to the ``sampler`` constructor
    """
    datapipe: IterDataPipe
    sampler: Sampler

    def __init__(self,
                 datapipe: IterDataPipe,
                 sampler: Type[Sampler] = SequentialSampler,
                 sampler_args: Optional[Tuple] = None,
                 sampler_kwargs: Optional[Dict] = None
                 ) -> None:
        assert isinstance(datapipe, Sized), \
            "Sampler class requires input datapipe implemented `__len__`"
        super().__init__()
        self.datapipe = datapipe
        self.sampler_args = () if sampler_args is None else sampler_args
        self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
        # https://github.com/python/mypy/pull/9629 will solve
        self.sampler = sampler(data_source=self.datapipe, *self.sampler_args, **self.sampler_kwargs)  # type: ignore[misc]

    def __iter__(self) -> Iterator[T_co]:
        # Delegate iteration entirely to the wrapped sampler.
        return iter(self.sampler)

    def __len__(self) -> int:
        # The input datapipe was checked as `Sized` in __init__, but the
        # sampler itself may still be un-sized (e.g. an infinite sampler).
        if isinstance(self.sampler, Sized) and len(self.sampler) >= 0:
            return len(self.sampler)
        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
@functional_datapipe('shuffle')
class ShuffleIterDataPipe(IterDataPipe[T_co]):
    r""" :class:`ShuffleIterDataPipe`

    Iterable DataPipe to shuffle the input DataPipe with a buffer. The buffer
    with `buffer_size` is filled with elements from the datapipe first. Then,
    each item will be yielded from the buffer by reservoir sampling via iterator.

    `buffer_size` is required to be larger than 0. For `buffer_size == 1`, the
    datapipe is not shuffled. In order to fully shuffle all elements from datapipe,
    `buffer_size` is required to be greater than or equal to the size of datapipe.

    When it is used with :class:`~torch.utils.data.DataLoader`, the methods to
    set up random seed are different based on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
    mode (:attr:`num_worker > 0`), `worker_init_fn` is used to set up a random seed
    for each worker process.

    Args:
        datapipe: The IterDataPipe being shuffled
        buffer_size: The buffer size for shuffling (default to 10000)
        unbatch_level: Specifies if it necessary to unbatch source data before
            applying the shuffle
    """
    datapipe: IterDataPipe[T_co]
    buffer_size: int
    # Kept for backward compatibility with `buffer_replace`; iteration now
    # uses a per-`__iter__` local buffer (see `__iter__` below).
    _buffer: List[T_co]

    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 *,
                 buffer_size: int = 10000,
                 unbatch_level: int = 0
                 ) -> None:
        super().__init__()
        assert buffer_size > 0, "buffer_size should be larger than 0"
        if unbatch_level == 0:
            self.datapipe = datapipe
        else:
            self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level)
        self.buffer_size = buffer_size
        self._buffer = []

    def buffer_replace(self, x):
        # Deprecated helper retained for backward compatibility: swaps `x`
        # into a uniformly random slot of `self._buffer` and returns the
        # evicted element. No longer used by `__iter__`.
        idx = random.randint(0, self.buffer_size - 1)
        val = self._buffer[idx]
        self._buffer[idx] = x
        return val

    def __iter__(self) -> Iterator[T_co]:
        # Fix for the previous TODO ("Buffer is global, should be per
        # __iter__"): the reservoir buffer is now local to each call, so a
        # restarted or concurrent iteration neither replays nor leaks
        # elements from a previous (possibly partial) pass.
        buffer: List[T_co] = []
        for x in self.datapipe:
            if len(buffer) == self.buffer_size:
                # Reservoir step: evict a uniformly random element, keep `x`.
                idx = random.randint(0, self.buffer_size - 1)
                val, buffer[idx] = buffer[idx], x
                yield val
            else:
                buffer.append(x)
        # Drain the remaining buffered elements in random order.
        random.shuffle(buffer)
        while buffer:
            yield buffer.pop()

    def __len__(self) -> int:
        # Shuffling never adds or drops elements, so length equals the source's.
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))