mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This is the result of applying the ruff `UP035` check. `Callable` is imported from `collections.abc` instead of `typing`. `TypeAlias` is also imported from `typing`. This PR is the follow-up of #163947. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164054 Approved by: https://github.com/ezyang, https://github.com/Skylion007
103 lines
3.2 KiB
Python
103 lines
3.2 KiB
Python
# mypy: allow-untyped-defs
|
|
from collections.abc import Callable, Iterator
|
|
from typing import TypeVar
|
|
|
|
from torch.utils.data.datapipes._decorator import functional_datapipe
|
|
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
|
|
from torch.utils.data.datapipes.datapipe import IterDataPipe
|
|
from torch.utils.data.datapipes.utils.common import (
|
|
_check_unpickable_fn,
|
|
StreamWrapper,
|
|
validate_input_col,
|
|
)
|
|
|
|
|
|
__all__ = ["FilterIterDataPipe"]
|
|
|
|
|
|
_T = TypeVar("_T")
|
|
_T_co = TypeVar("_T_co", covariant=True)
|
|
|
|
|
|
@functional_datapipe("filter")
|
|
class FilterIterDataPipe(IterDataPipe[_T_co]):
|
|
r"""
|
|
Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``).
|
|
|
|
Args:
|
|
datapipe: Iterable DataPipe being filtered
|
|
filter_fn: Customized function mapping an element to a boolean.
|
|
input_col: Index or indices of data which ``filter_fn`` is applied, such as:
|
|
|
|
- ``None`` as default to apply ``filter_fn`` to the data directly.
|
|
- Integer(s) is used for list/tuple.
|
|
- Key(s) is used for dict.
|
|
|
|
Example:
|
|
>>> # xdoctest: +SKIP
|
|
>>> from torchdata.datapipes.iter import IterableWrapper
|
|
>>> def is_even(n):
|
|
... return n % 2 == 0
|
|
>>> dp = IterableWrapper(range(5))
|
|
>>> filter_dp = dp.filter(filter_fn=is_even)
|
|
>>> list(filter_dp)
|
|
[0, 2, 4]
|
|
"""
|
|
|
|
datapipe: IterDataPipe[_T_co]
|
|
filter_fn: Callable
|
|
|
|
def __init__(
|
|
self,
|
|
datapipe: IterDataPipe[_T_co],
|
|
filter_fn: Callable,
|
|
input_col=None,
|
|
) -> None:
|
|
super().__init__()
|
|
self.datapipe = datapipe
|
|
|
|
_check_unpickable_fn(filter_fn)
|
|
self.filter_fn = filter_fn # type: ignore[assignment]
|
|
|
|
self.input_col = input_col
|
|
validate_input_col(filter_fn, input_col)
|
|
|
|
def _apply_filter_fn(self, data) -> bool:
|
|
if self.input_col is None:
|
|
return self.filter_fn(data)
|
|
elif isinstance(self.input_col, (list, tuple)):
|
|
args = tuple(data[col] for col in self.input_col)
|
|
return self.filter_fn(*args)
|
|
else:
|
|
return self.filter_fn(data[self.input_col])
|
|
|
|
def __iter__(self) -> Iterator[_T_co]:
|
|
for data in self.datapipe:
|
|
condition, filtered = self._returnIfTrue(data)
|
|
if condition:
|
|
yield filtered
|
|
else:
|
|
StreamWrapper.close_streams(data)
|
|
|
|
def _returnIfTrue(self, data: _T) -> tuple[bool, _T]:
|
|
condition = self._apply_filter_fn(data)
|
|
|
|
if df_wrapper.is_column(condition):
|
|
# We are operating on DataFrames filter here
|
|
result = []
|
|
for idx, mask in enumerate(df_wrapper.iterate(condition)):
|
|
if mask:
|
|
result.append(df_wrapper.get_item(data, idx))
|
|
if len(result):
|
|
return True, df_wrapper.concat(result)
|
|
else:
|
|
return False, None # type: ignore[return-value]
|
|
|
|
if not isinstance(condition, bool):
|
|
raise ValueError(
|
|
"Boolean output is required for `filter_fn` of FilterIterDataPipe, got",
|
|
type(condition),
|
|
)
|
|
|
|
return condition, data
|