mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Pull Request resolved: https://github.com/pytorch/pytorch/pull/144552 Approved by: https://github.com/ezyang
		
			
				
	
	
		
			244 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			244 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # mypy: allow-untyped-defs
 | |
| import functools
 | |
| from collections import namedtuple
 | |
| from collections.abc import Iterator, Sized
 | |
| from typing import Any, Callable, Optional, TypeVar, Union
 | |
| 
 | |
| import torch
 | |
| from torch.utils.data._utils.collate import default_collate
 | |
| from torch.utils.data.datapipes._decorator import functional_datapipe
 | |
| from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
 | |
| from torch.utils.data.datapipes.datapipe import IterDataPipe
 | |
| from torch.utils.data.datapipes.utils.common import (
 | |
|     _check_unpickable_fn,
 | |
|     validate_input_col,
 | |
| )
 | |
| 
 | |
| 
 | |
# Names exported by ``from <this module> import *``.
__all__ = ["CollatorIterDataPipe", "MapperIterDataPipe"]


# Covariant type variable for the items yielded by ``MapperIterDataPipe``.
_T_co = TypeVar("_T_co", covariant=True)
 | |
| 
 | |
| 
 | |
@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Applies a function over each item from the source DataPipe (functional name: ``map``).

    The function can be any regular Python function or partial object. Lambda
    function is not recommended as it is not supported by pickle.

    Args:
        datapipe: Source Iterable DataPipe
        fn: Function being applied over each item
        input_col: Index or indices of data to which ``fn`` is applied, such as:

            - ``None`` as default to apply ``fn`` to the data directly.
            - Integer(s) is used for list/tuple.
            - Key(s) is used for dict.

        output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
            only when ``input_col`` is not ``None``

            - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
              multiple indices, the left-most one is used, and other indices will be removed.
            - Integer is used for list/tuple. ``-1`` represents to append result at the end.
            - Key is used for dict. New key is acceptable.

    When ``input_col`` is set, list and dict items are updated in place; tuple
    items are copied to a list, updated, and converted back to a tuple.

    Raises:
        ValueError: if ``output_col`` is given while ``input_col`` is ``None``, or if
            ``output_col`` is a list/tuple that does not contain exactly one element.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
        >>> def add_one(x):
        ...     return x + 1
        >>> dp = IterableWrapper(range(10))
        >>> # Invocation via functional form is preferred
        >>> map_dp_1 = dp.map(add_one)
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
        >>> # Use `functools.partial` or explicitly define the function instead
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    """

    datapipe: IterDataPipe
    fn: Callable

    def __init__(
        self,
        datapipe: IterDataPipe,
        fn: Callable,
        input_col=None,
        output_col=None,
    ) -> None:
        torch._C._log_api_usage_once("python.data_pipes.map")
        super().__init__()
        self.datapipe = datapipe

        _check_unpickable_fn(fn)
        self.fn = fn  # type: ignore[assignment]

        self.input_col = input_col
        if input_col is None and output_col is not None:
            raise ValueError("`output_col` must be None when `input_col` is None.")
        if isinstance(output_col, (list, tuple)):
            # Reject empty sequences too; previously only `len > 1` was checked and
            # an empty list/tuple crashed with IndexError on `output_col[0]`.
            if len(output_col) != 1:
                raise ValueError("`output_col` must be a single-element list or tuple")
            output_col = output_col[0]
        self.output_col = output_col
        validate_input_col(fn, input_col)

    @staticmethod
    def _remove_cols(data, cols):
        """Delete every index/key in ``cols`` from the mutable container ``data``."""
        if isinstance(data, dict):
            # Deletion order does not matter for dict keys, and sorting could raise
            # TypeError when keys of different types (e.g. str and int) are mixed.
            for key in cols:
                del data[key]
            return
        size = len(data)
        # Resolve negative indices to absolute positions *before* sorting, so that
        # deleting from the highest position downwards never shifts an index that
        # is still pending removal (e.g. deleting -1 then -2 would be wrong).
        positions = [
            idx + size if isinstance(idx, int) and idx < 0 else idx for idx in cols
        ]
        for pos in sorted(positions, reverse=True):
            del data[pos]

    def _apply_fn(self, data):
        """Apply ``fn`` to one item and place the result per ``input_col``/``output_col``."""
        if self.input_col is None and self.output_col is None:
            return self.fn(data)

        if self.input_col is None:
            res = self.fn(data)
        elif isinstance(self.input_col, (list, tuple)):
            args = tuple(data[col] for col in self.input_col)
            res = self.fn(*args)
        else:
            res = self.fn(data[self.input_col])

        # Copy tuple to list and run in-place modification because tuple is immutable.
        # Lists and dicts are modified in place.
        is_tuple = isinstance(data, tuple)
        if is_tuple:
            data = list(data)

        if self.output_col is None:
            if isinstance(self.input_col, (list, tuple)):
                # The left-most input column receives the result; the rest are dropped.
                data[self.input_col[0]] = res
                self._remove_cols(data, self.input_col[1:])
            else:
                data[self.input_col] = res
        elif self.output_col == -1:
            data.append(res)
        else:
            data[self.output_col] = res

        # Convert list back to tuple
        return tuple(data) if is_tuple else data

    def __iter__(self) -> Iterator[_T_co]:
        for data in self.datapipe:
            yield self._apply_fn(data)

    def __len__(self) -> int:
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
 | |
| 
 | |
| 
 | |
| def _collate_helper(conversion, item):
 | |
|     # TODO(VitalyFedyunin): Verify that item is any sort of batch
 | |
|     if len(item.items) > 1:
 | |
|         # TODO(VitalyFedyunin): Compact all batch dataframes into one
 | |
|         raise RuntimeError("Only supports one DataFrame per batch")
 | |
|     df = item[0]
 | |
|     columns_name = df_wrapper.get_columns(df)
 | |
|     tuple_names: list = []
 | |
|     tuple_values: list = []
 | |
| 
 | |
|     for name in conversion.keys():
 | |
|         if name not in columns_name:
 | |
|             raise RuntimeError("Conversion keys mismatch")
 | |
| 
 | |
|     for name in columns_name:
 | |
|         if name in conversion:
 | |
|             if not callable(conversion[name]):
 | |
|                 raise RuntimeError(
 | |
|                     "Collate (DF)DataPipe requires callable as dict values"
 | |
|                 )
 | |
|             collation_fn = conversion[name]
 | |
|         else:
 | |
|             # TODO(VitalyFedyunin): Add default collation into df_wrapper
 | |
|             try:
 | |
|                 import torcharrow.pytorch as tap  # type: ignore[import]
 | |
| 
 | |
|                 collation_fn = tap.rec.Default()
 | |
|             except Exception as e:
 | |
|                 raise RuntimeError(
 | |
|                     "unable to import default collation function from the TorchArrow"
 | |
|                 ) from e
 | |
| 
 | |
|         tuple_names.append(str(name))
 | |
|         value = collation_fn(df[name])
 | |
|         tuple_values.append(value)
 | |
| 
 | |
|     # TODO(VitalyFedyunin): We can dynamically extract types from the tuple_values here
 | |
|     # TODO(VitalyFedyunin): Instead of ignoring mypy error, make sure tuple_names is not empty
 | |
|     tpl_cls = namedtuple("CollateResult", tuple_names)  # type: ignore[misc]
 | |
|     tuple = tpl_cls(*tuple_values)
 | |
|     return tuple
 | |
| 
 | |
| 
 | |
@functional_datapipe("collate")
class CollatorIterDataPipe(MapperIterDataPipe):
    r"""
    Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).

    By default, it uses :func:`torch.utils.data.default_collate`.

    .. note::
        While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
        default behavior and `functools.partial` to specify any additional arguments.

    Args:
        datapipe: Iterable DataPipe being collated
        conversion: Used only when ``collate_fn`` is ``None``. Either a callable applied to each
            item (defaults to :func:`torch.utils.data.default_collate`), or a dict mapping DataFrame
            column names to per-column collation callables. In the dict form each item is expected to
            be a batch holding a single DataFrame, and columns missing from the dict fall back to
            TorchArrow's default collation.
        collate_fn: Customized collate function to collect and combine data or a batch of data.
            Default function collates to Tensor(s) based on data type. When given, it takes
            precedence over ``conversion``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Convert integer data to float Tensor
        >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        ...     def __len__(self):
        ...         return self.end - self.start
        >>> ds = MyIterDataPipe(start=3, end=7)
        >>> print(list(ds))
        [3, 4, 5, 6]
        >>> def collate_fn(batch):
        ...     return torch.tensor(batch, dtype=torch.float)
        >>> collated_ds = CollatorIterDataPipe(ds, collate_fn=collate_fn)
        >>> print(list(collated_ds))
        [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
    """

    def __init__(
        self,
        datapipe: IterDataPipe,
        conversion: Union[
            Callable[..., Any], dict[Union[str, Any], Union[Callable, Any]], None
        ] = default_collate,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        # TODO(VitalyFedyunin): Replace `Callable[..., Any]` with `Callable[[IColumn], Any]`
        # TODO(VitalyFedyunin): Replace with `Dict[Union[str, IColumn], Union[Callable, Enum]]`
        if collate_fn is not None:
            fn = collate_fn
        elif callable(conversion):
            fn = conversion
        else:
            # TODO(VitalyFedyunin): Validate passed dictionary
            # A non-callable `conversion` is a per-column mapping; `_collate_helper`
            # applies it to the single DataFrame inside each batch.
            fn = functools.partial(_collate_helper, conversion)
        super().__init__(datapipe, fn=fn)
 |