# Mirror of https://github.com/pytorch/pytorch.git
# Synced 2025-10-20 21:14:14 +08:00
# Pull Request resolved: https://github.com/pytorch/pytorch/pull/157011
# Approved by: https://github.com/Skylion007
# 60 lines, 2.0 KiB, Python
import copy
|
|
import warnings
|
|
from collections.abc import Iterable, Iterator, Sized
|
|
from typing import TypeVar
|
|
|
|
from torch.utils.data.datapipes.datapipe import IterDataPipe
|
|
|
|
|
|
_T = TypeVar("_T")

__all__ = ["IterableWrapperIterDataPipe"]


class IterableWrapperIterDataPipe(IterDataPipe[_T]):
    r"""
    Expose an arbitrary iterable object through the IterDataPipe interface.

    Args:
        iterable: The iterable whose elements this DataPipe yields.
        deepcopy: When true (the default), every new iterator walks a
            ``copy.deepcopy`` of ``iterable`` rather than the object itself.
            The copy is taken lazily, once the first element is requested
            from ``iter()``.

    .. note::
        With ``deepcopy=False`` each iterator reads straight from the wrapped
        object, so the pipeline must avoid in-place operations on it;
        otherwise later iterations may observe modified data.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> list(dp)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    """

    def __init__(self, iterable: Iterable[_T], deepcopy: bool = True) -> None:
        # Only references are stored here; any copying is deferred to __iter__.
        self.deepcopy = deepcopy
        self.iterable = iterable

    def __iter__(self) -> Iterator[_T]:
        """Yield the wrapped elements, from a deep copy when ``deepcopy`` is set."""
        wrapped = self.iterable

        # Fast path: the caller opted out of copying, so stream the original.
        if not self.deepcopy:
            yield from wrapped
            return

        try:
            snapshot = copy.deepcopy(wrapped)
        except TypeError:
            # Some objects cannot be deep-copied. Fall back to the original
            # object and tell the user: in-place operations downstream will
            # then change what a second pass over this DataPipe yields.
            warnings.warn(
                "The input iterable can not be deepcopied, "
                "please be aware of in-place modification would affect source data."
            )
            snapshot = wrapped

        yield from snapshot

    def __len__(self) -> int:
        """Return ``len()`` of the wrapped iterable.

        Raises:
            TypeError: If the wrapped iterable is not ``Sized``.
        """
        wrapped = self.iterable
        if not isinstance(wrapped, Sized):
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return len(wrapped)
|