pytorch/torch/utils/data/datapipes/iter/callable.py
Erjia Guan fd9e08df5d Make Demux serializable with lambda function (#71311)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71311

Test Plan: Imported from OSS

Reviewed By: NivekT

Differential Revision: D33584552

Pulled By: ejguan

fbshipit-source-id: 52324faf5547f9f77582ec170ec91ce3114cfc61
2022-01-18 06:47:54 -08:00

182 lines
6.3 KiB
Python

from typing import Callable, Iterator, Sized, TypeVar
from torch.utils.data import IterDataPipe, _utils, functional_datapipe
from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE, check_lambda_fn
if DILL_AVAILABLE:
import dill
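# Ensure dill does not globally replace the stdlib pickle behavior;
# dill is only used explicitly in __getstate__/__setstate__ below.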
dill.extend(use_dill=False)
T_co = TypeVar("T_co", covariant=True)
@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe[T_co]):
r""":class:`MapperIterDataPipe`.
Iterable DataPipe to run a function over each item from the source DataPipe.
The function can be any regular Python function or partial object. Lambda
functions are not recommended, as they are not supported by pickle.
Args:
datapipe: Source Iterable DataPipe
fn: Function called over each item
input_col: Index or indices of data to which `fn` is applied
- None as default to apply `fn` to the data directly.
- Integer(s) are used for list/tuple.
- Key(s) are used for dict.
output_col: Index of data where the result of `fn` is placed. `output_col` can be specified only when `input_col` is not None.
- None as default to replace the index that `input_col` specifies;
for `input_col` with multiple indices, the left-most one is used and the other indices are removed.
- Integer is used for list/tuple. -1 appends the result to the end.
- Key is used for dict. New key is acceptable.
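Example: Apply `fn` to one column of each item (a minimal sketch; it assumes
`IterableWrapper` from `torch.utils.data.datapipes.iter` is available and uses the
functional `map` form registered by this class)
>>> from torch.utils.data.datapipes.iter import IterableWrapper
>>> dp = IterableWrapper([(0, 10), (1, 20)])
>>> list(dp.map(str, input_col=1))  # replace the element at index 1 with str(data[1])
[(0, '10'), (1, '20')]
>>> list(dp.map(str, input_col=1, output_col=-1))  # append the result instead of replacing
[(0, 10, '10'), (1, 20, '20')]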
"""
datapipe: IterDataPipe
fn: Callable
def __init__(
self,
datapipe: IterDataPipe,
fn: Callable,
input_col=None,
output_col=None,
) -> None:
super().__init__()
self.datapipe = datapipe
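# check_lambda_fn warns if fn is a lambda and dill is not available,
# since lambdas cannot be pickled by the stdlib.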
check_lambda_fn(fn)
self.fn = fn # type: ignore[assignment]
self.input_col = input_col
if input_col is None and output_col is not None:
raise ValueError("`output_col` must be None when `input_col` is None.")
if isinstance(output_col, (list, tuple)):
if len(output_col) > 1:
raise ValueError("`output_col` must be a single-element list or tuple")
output_col = output_col[0]
self.output_col = output_col
def _apply_fn(self, data):
if self.input_col is None and self.output_col is None:
return self.fn(data)
if self.input_col is None:
res = self.fn(data)
elif isinstance(self.input_col, (list, tuple)):
args = tuple(data[col] for col in self.input_col)
res = self.fn(*args)
else:
res = self.fn(data[self.input_col])
# Copy tuple to list and run in-place modification because tuple is immutable.
if isinstance(data, tuple):
t_flag = True
data = list(data)
else:
t_flag = False
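# Decide where to write the result: replace the input column(s) when output_col is None,
# otherwise write to output_col.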
if self.output_col is None:
if isinstance(self.input_col, (list, tuple)):
data[self.input_col[0]] = res
for idx in sorted(self.input_col[1:], reverse=True):
del data[idx]
else:
data[self.input_col] = res
else:
if self.output_col == -1:
data.append(res)
else:
data[self.output_col] = res
# Convert list back to tuple
return tuple(data) if t_flag else data
def __iter__(self) -> Iterator[T_co]:
for data in self.datapipe:
yield self._apply_fn(data)
def __len__(self) -> int:
if isinstance(self.datapipe, Sized):
return len(self.datapipe)
raise TypeError(
"{} instance doesn't have valid length".format(type(self).__name__)
)
def __getstate__(self):
if IterDataPipe.getstate_hook is not None:
return IterDataPipe.getstate_hook(self)
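# Serialize fn with dill when it is available so lambdas and locally defined functions
# survive pickling; otherwise store the callable as-is and rely on the stdlib pickle.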
if DILL_AVAILABLE:
dill_function = dill.dumps(self.fn)
else:
dill_function = self.fn
state = (
self.datapipe,
dill_function,
self.input_col,
self.output_col,
)
return state
def __setstate__(self, state):
(
self.datapipe,
dill_function,
self.input_col,
self.output_col,
) = state
if DILL_AVAILABLE:
self.fn = dill.loads(dill_function) # type: ignore[assignment]
else:
self.fn = dill_function # type: ignore[assignment]
@functional_datapipe("collate")
class CollatorIterDataPipe(MapperIterDataPipe):
r""":class:`CollatorIterDataPipe`.
Iterable DataPipe to collate samples from the source DataPipe into Tensor(s) by a custom collate function,
which defaults to `torch.utils.data.default_collate` if it is not specified.
.. note::
While writing a custom collate function, you can import `torch.utils.data.default_collate` for the
default behavior and `functools.partial` to specify any additional arguments.
Args:
datapipe: Iterable DataPipe being collated
collate_fn: Customized collate function to collect and combine data or a batch of data.
Default function collates to Tensor(s) based on data type.
Example: Convert integer data to float Tensor
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
... def __init__(self, start, end):
... super().__init__()
... assert end > start, "this example code only works with end > start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... return iter(range(self.start, self.end))
...
... def __len__(self):
... return self.end - self.start
...
>>> ds = MyIterDataPipe(start=3, end=7)
>>> print(list(ds))
[3, 4, 5, 6]
>>> def collate_fn(batch):
... return torch.tensor(batch, dtype=torch.float)
...
>>> collated_ds = CollatorIterDataPipe(ds, collate_fn=collate_fn)
>>> print(list(collated_ds))
[tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
"""
def __init__(
self,
datapipe: IterDataPipe,
collate_fn: Callable = _utils.collate.default_collate,
) -> None:
super().__init__(datapipe, fn=collate_fn)