mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This is the result of applying the ruff `UP035` check. `Callable` is imported from `collections.abc` instead of `typing`. `TypeAlias` is also imported from `typing`. This PR is the follow-up of #163947. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164054 Approved by: https://github.com/ezyang, https://github.com/Skylion007
244 lines
8.9 KiB
Python
244 lines
8.9 KiB
Python
# mypy: allow-untyped-defs
|
|
import functools
|
|
from collections import namedtuple
|
|
from collections.abc import Callable, Iterator, Sized
|
|
from typing import Any, Optional, TypeVar, Union
|
|
|
|
import torch
|
|
from torch.utils.data._utils.collate import default_collate
|
|
from torch.utils.data.datapipes._decorator import functional_datapipe
|
|
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
|
|
from torch.utils.data.datapipes.datapipe import IterDataPipe
|
|
from torch.utils.data.datapipes.utils.common import (
|
|
_check_unpickable_fn,
|
|
validate_input_col,
|
|
)
|
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = [
    "CollatorIterDataPipe",
    "MapperIterDataPipe",
]

# Covariant type variable for the item type; it parameterizes ``IterDataPipe``
# in the class definition and ``Iterator`` in ``__iter__``.
_T_co = TypeVar("_T_co", covariant=True)
|
|
|
|
|
|
@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Applies a function over each item from the source DataPipe (functional name: ``map``).

    The function can be any regular Python function or partial object. Lambda
    function is not recommended as it is not supported by pickle.

    Args:
        datapipe: Source Iterable DataPipe
        fn: Function being applied over each item
        input_col: Index or indices of data which ``fn`` is applied, such as:

            - ``None`` as default to apply ``fn`` to the data directly.
            - Integer(s) is used for list/tuple.
            - Key(s) is used for dict.

        output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
            only when ``input_col`` is not ``None``

            - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
              multiple indices, the left-most one is used, and other indices will be removed.
            - Integer is used for list/tuple. ``-1`` represents to append result at the end.
            - Key is used for dict. New key is acceptable.
            - A single-element list or tuple is unwrapped to its only element.

    Raises:
        ValueError: If ``output_col`` is given while ``input_col`` is ``None``, or if
            ``output_col`` is a list/tuple that does not contain exactly one element.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
        >>> def add_one(x):
        ...     return x + 1
        >>> dp = IterableWrapper(range(10))
        >>> # Invocation via functional form is preferred
        >>> map_dp_1 = dp.map(add_one)
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
        >>> # Use `functools.partial` or explicitly define the function instead
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    """

    datapipe: IterDataPipe
    fn: Callable

    def __init__(
        self,
        datapipe: IterDataPipe,
        fn: Callable,
        input_col=None,
        output_col=None,
    ) -> None:
        torch._C._log_api_usage_once("python.data_pipes.map")
        super().__init__()
        self.datapipe = datapipe

        _check_unpickable_fn(fn)
        self.fn = fn  # type: ignore[assignment]

        self.input_col = input_col
        if input_col is None and output_col is not None:
            raise ValueError("`output_col` must be None when `input_col` is None.")
        if isinstance(output_col, (list, tuple)):
            # Reject empty sequences too: indexing ``output_col[0]`` on an empty
            # list/tuple would otherwise surface as an unrelated IndexError.
            if len(output_col) != 1:
                raise ValueError("`output_col` must be a single-element list or tuple")
            output_col = output_col[0]
        self.output_col = output_col
        validate_input_col(fn, input_col)

    def _apply_fn(self, data):
        """Apply ``self.fn`` to ``data`` according to ``input_col``/``output_col``.

        With neither column set, the result of ``fn(data)`` is returned directly.
        Otherwise the result is written back into ``data``: tuples are copied to
        a list first and converted back afterwards, while any other container is
        modified in place and returned.
        """
        if self.input_col is None and self.output_col is None:
            return self.fn(data)

        if self.input_col is None:
            res = self.fn(data)
        elif isinstance(self.input_col, (list, tuple)):
            # Multiple input columns are passed to ``fn`` as positional arguments.
            args = tuple(data[col] for col in self.input_col)
            res = self.fn(*args)
        else:
            res = self.fn(data[self.input_col])

        # Copy tuple to list and run in-place modification because tuple is immutable.
        if isinstance(data, tuple):
            t_flag = True
            data = list(data)
        else:
            t_flag = False

        if self.output_col is None:
            if isinstance(self.input_col, (list, tuple)):
                # The left-most input column receives the result ...
                data[self.input_col[0]] = res
                # ... and the remaining input columns are removed, highest index
                # first so earlier deletions do not shift later ones.
                for idx in sorted(self.input_col[1:], reverse=True):
                    del data[idx]
            else:
                data[self.input_col] = res
        else:
            if self.output_col == -1:
                data.append(res)
            else:
                data[self.output_col] = res

        # Convert list back to tuple
        return tuple(data) if t_flag else data

    def __iter__(self) -> Iterator[_T_co]:
        """Yield the transformed version of every item of the source DataPipe."""
        for data in self.datapipe:
            yield self._apply_fn(data)

    def __len__(self) -> int:
        """Return the length of the source DataPipe.

        Raises:
            TypeError: If the source DataPipe is not ``Sized``.
        """
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
|
|
|
|
|
|
def _collate_helper(conversion, item):
|
|
# TODO(VitalyFedyunin): Verify that item is any sort of batch
|
|
if len(item.items) > 1:
|
|
# TODO(VitalyFedyunin): Compact all batch dataframes into one
|
|
raise RuntimeError("Only supports one DataFrame per batch")
|
|
df = item[0]
|
|
columns_name = df_wrapper.get_columns(df)
|
|
tuple_names: list = []
|
|
tuple_values: list = []
|
|
|
|
for name in conversion.keys():
|
|
if name not in columns_name:
|
|
raise RuntimeError("Conversion keys mismatch")
|
|
|
|
for name in columns_name:
|
|
if name in conversion:
|
|
if not callable(conversion[name]):
|
|
raise RuntimeError(
|
|
"Collate (DF)DataPipe requires callable as dict values"
|
|
)
|
|
collation_fn = conversion[name]
|
|
else:
|
|
# TODO(VitalyFedyunin): Add default collation into df_wrapper
|
|
try:
|
|
import torcharrow.pytorch as tap # type: ignore[import]
|
|
|
|
collation_fn = tap.rec.Default()
|
|
except Exception as e:
|
|
raise RuntimeError(
|
|
"unable to import default collation function from the TorchArrow"
|
|
) from e
|
|
|
|
tuple_names.append(str(name))
|
|
value = collation_fn(df[name])
|
|
tuple_values.append(value)
|
|
|
|
# TODO(VitalyFedyunin): We can dynamically extract types from the tuple_values here
|
|
# TODO(VitalyFedyunin): Instead of ignoring mypy error, make sure tuple_names is not empty
|
|
tpl_cls = namedtuple("CollateResult", tuple_names) # type: ignore[misc]
|
|
tuple = tpl_cls(*tuple_values)
|
|
return tuple
|
|
|
|
|
|
@functional_datapipe("collate")
class CollatorIterDataPipe(MapperIterDataPipe):
    r"""
    Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).

    By default, it uses :func:`torch.utils.data.default_collate`.

    .. note::
        While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
        default behavior and `functools.partial` to specify any additional arguments.

    Args:
        datapipe: Iterable DataPipe being collated
        conversion: Used only when ``collate_fn`` is ``None``. Either a callable applied to each
            item, or a dict mapping DataFrame column names to the callables used to collate
            those columns. Defaults to :func:`torch.utils.data.default_collate`.
        collate_fn: Customized collate function to collect and combine data or a batch of data.
            Default function collates to Tensor(s) based on data type. When given, it takes
            precedence over ``conversion``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Convert integer data to float Tensor
        >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        ...     def __len__(self):
        ...         return self.end - self.start
        >>> ds = MyIterDataPipe(start=3, end=7)
        >>> print(list(ds))
        [3, 4, 5, 6]
        >>> def collate_fn(batch):
        ...     return torch.tensor(batch, dtype=torch.float)
        >>> collated_ds = CollatorIterDataPipe(ds, collate_fn=collate_fn)
        >>> print(list(collated_ds))
        [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
    """

    def __init__(
        self,
        datapipe: IterDataPipe,
        conversion: Union[
            Callable[..., Any], dict[Union[str, Any], Union[Callable, Any]], None
        ] = default_collate,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        # TODO(VitalyFedyunin): Replace `Callable[..., Any]` with `Callable[[IColumn], Any]`
        # TODO(VitalyFedyunin): Replace with `Dict[Union[str, IColumn], Union[Callable, Enum]]`
        if collate_fn is None:
            if callable(conversion):
                collate_fn = conversion
            else:
                # A non-callable ``conversion`` is treated as a per-column
                # mapping and bound to the DataFrame collation helper.
                # TODO(VitalyFedyunin): Validate passed dictionary
                collate_fn = functools.partial(_collate_helper, conversion)
        super().__init__(datapipe, fn=collate_fn)
|