mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71311 Test Plan: Imported from OSS Reviewed By: NivekT Differential Revision: D33584552 Pulled By: ejguan fbshipit-source-id: 52324faf5547f9f77582ec170ec91ce3114cfc61
182 lines
6.3 KiB
Python
182 lines
6.3 KiB
Python
import copy
from typing import Callable, Iterator, Sized, TypeVar

from torch.utils.data import IterDataPipe, _utils, functional_datapipe
from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE, check_lambda_fn
|
|
|
|
# `dill` is an optional dependency. When it is importable, the Mapper DataPipe's
# `__getstate__`/`__setstate__` serialize `fn` through `dill.dumps`/`dill.loads`,
# presumably so that callables plain `pickle` rejects (such as lambdas) can
# still be serialized.
if DILL_AVAILABLE:
    import dill

    # NOTE(review): `use_dill=False` presumably keeps dill from registering its
    # picklers with the stdlib `pickle` module, so importing this file does not
    # change global pickle behaviour -- confirm against the `dill.extend` docs.
    dill.extend(use_dill=False)


# Covariant element type used to parameterize the IterDataPipe subclasses here.
T_co = TypeVar("T_co", covariant=True)
|
|
|
|
|
|
@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe[T_co]):
    r""":class:`MapperIterDataPipe`.

    Iterable DataPipe to run a function over each item from the source DataPipe.
    The function can be any regular python function or partial object. Lambda
    function is not recommended as it is not supported by pickle.

    When `input_col` is given, the result is written into a shallow copy of the
    item (tuples are copied into a list and converted back), so items yielded by
    the source DataPipe are never modified in place.

    Args:
        datapipe: Source Iterable DataPipe
        fn: Function called over each item
        input_col: Index or indices of data which `fn` is applied
            - None as default to apply `fn` to the data directly.
            - Integer(s) is used for list/tuple. Negative indices are allowed.
            - Key(s) is used for dict.
        output_col: Index of data where result of `fn` is placed. `output_col` can be specified only when `input_col` is not None
            - None as default to replace the index that `input_col` specified;
              For `input_col` with multiple indices, the left-most one is used, and other indices will be removed.
            - Integer is used for list/tuple. -1 represents to append result at the end.
            - Key is used for dict. New key is acceptable.
            A single-element list/tuple is unwrapped to its only element.

    Raises:
        ValueError: If `output_col` is given while `input_col` is None, or if
            `output_col` is a list/tuple that does not hold exactly one element.
    """

    datapipe: IterDataPipe  # upstream DataPipe whose items are mapped
    fn: Callable  # function applied to each item (or to the selected columns)

    def __init__(
        self,
        datapipe: IterDataPipe,
        fn: Callable,
        input_col=None,
        output_col=None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe

        # NOTE(review): presumably warns about / rejects lambdas, which the
        # docstring says plain pickle cannot serialize -- see check_lambda_fn.
        check_lambda_fn(fn)
        self.fn = fn  # type: ignore[assignment]

        self.input_col = input_col
        if input_col is None and output_col is not None:
            raise ValueError("`output_col` must be None when `input_col` is None.")
        if isinstance(output_col, (list, tuple)):
            # `!= 1` (rather than `> 1`) so an empty sequence raises the
            # documented ValueError instead of an IndexError on `output_col[0]`.
            if len(output_col) != 1:
                raise ValueError("`output_col` must be a single-element list or tuple")
            output_col = output_col[0]
        self.output_col = output_col

    def _apply_fn(self, data):
        """Apply `fn` to one item according to `input_col`/`output_col`."""
        if self.input_col is None and self.output_col is None:
            return self.fn(data)

        if self.input_col is None:
            # Only reachable if attributes were changed after construction;
            # `__init__` forbids `output_col` without `input_col`.
            res = self.fn(data)
        elif isinstance(self.input_col, (list, tuple)):
            args = tuple(data[col] for col in self.input_col)
            res = self.fn(*args)
        else:
            res = self.fn(data[self.input_col])

        # Modify a shallow copy so the upstream item is never mutated in place
        # (the source may hand out the same object again, e.g. on a later epoch).
        # Tuples are immutable, so they are copied into a list and converted back.
        if isinstance(data, tuple):
            t_flag = True
            data = list(data)
        else:
            t_flag = False
            data = copy.copy(data)

        if self.output_col is None:
            if isinstance(self.input_col, (list, tuple)):
                self._replace_cols(data, self.input_col, res)
            else:
                data[self.input_col] = res
        else:
            if self.output_col == -1:
                # Relies on `append`, so this applies to sequences, not dicts.
                data.append(res)
            else:
                data[self.output_col] = res

        # Convert list back to tuple
        return tuple(data) if t_flag else data

    @staticmethod
    def _replace_cols(data, cols, res):
        """Store `res` at `cols[0]` of `data` and delete the other `cols` in place."""
        target = cols[0]
        data[target] = res

        if isinstance(data, dict):
            # Keys have no position, so deletion order is irrelevant. Dedupe and
            # never delete the key that now holds the result.
            for key in dict.fromkeys(cols[1:]):
                if key != target:
                    del data[key]
            return

        def _normalize(idx):
            # Map negative integer indices to their non-negative equivalent so
            # they sort correctly against positive ones.
            if isinstance(idx, int) and idx < 0:
                return idx + len(data)
            return idx

        # All indices are normalized before any deletion, against the full
        # length. Deleting in descending order then never shifts an index that
        # is still pending. Example: [0, 3, -2] on 4 items removes items 3 and 2.
        target = _normalize(target)
        pending = {_normalize(idx) for idx in cols[1:]} - {target}
        for idx in sorted(pending, reverse=True):
            del data[idx]

    def __iter__(self) -> Iterator[T_co]:
        for data in self.datapipe:
            yield self._apply_fn(data)

    def __len__(self) -> int:
        # Mapping is 1:1, so the length is the source length when it has one.
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError(
            "{} instance doesn't have valid length".format(type(self).__name__)
        )

    def __getstate__(self):
        # A globally installed hook on IterDataPipe takes precedence.
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(self)

        # With dill available, `fn` is pre-serialized to bytes by dill, so the
        # outer pickler only has to handle a bytes object.
        if DILL_AVAILABLE:
            dill_function = dill.dumps(self.fn)
        else:
            dill_function = self.fn
        state = (
            self.datapipe,
            dill_function,
            self.input_col,
            self.output_col,
        )
        return state

    def __setstate__(self, state):
        (
            self.datapipe,
            dill_function,
            self.input_col,
            self.output_col,
        ) = state
        # Mirror of `__getstate__`: undo the dill serialization of `fn`.
        if DILL_AVAILABLE:
            self.fn = dill.loads(dill_function)  # type: ignore[assignment]
        else:
            self.fn = dill_function  # type: ignore[assignment]
|
|
|
|
|
|
@functional_datapipe("collate")
class CollatorIterDataPipe(MapperIterDataPipe):
    r""":class:`CollatorIterDataPipe`.

    Iterable DataPipe to collate samples from DataPipe to Tensor(s) by a custom collate function,
    which defaults to `torch.utils.data.default_collate` if it is not specified.

    .. note::
        While writing a custom collate function, you can import `torch.utils.data.default_collate` for the
        default behavior and `functools.partial` to specify any additional arguments.

    Args:
        datapipe: Iterable DataPipe being collated
        collate_fn: Customized collate function to collect and combine data or a batch of data.
            Default function collates to Tensor(s) based on data type.

    Example: Convert integer data to float Tensor
        >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        ...     def __len__(self):
        ...         return self.end - self.start
        ...
        >>> ds = MyIterDataPipe(start=3, end=7)
        >>> print(list(ds))
        [3, 4, 5, 6]

        >>> def collate_fn(batch):
        ...     return torch.tensor(batch, dtype=torch.float)
        ...
        >>> collated_ds = CollatorIterDataPipe(ds, collate_fn=collate_fn)
        >>> print(list(collated_ds))
        [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
    """

    def __init__(
        self,
        datapipe: IterDataPipe,
        collate_fn: Callable = _utils.collate.default_collate,
    ) -> None:
        # Collation is a plain map of `collate_fn` over every item; no
        # `input_col`/`output_col`, so the whole item is passed to `collate_fn`.
        super().__init__(datapipe, fn=collate_fn)
|