mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This is the result of applying the ruff `UP035` check. `Callable` is imported from `collections.abc` instead of `typing`. `TypeAlias` is also imported from `typing`. This PR is the follow-up of #163947. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164054 Approved by: https://github.com/ezyang, https://github.com/Skylion007
71 lines
2.7 KiB
Python
71 lines
2.7 KiB
Python
from collections.abc import Callable, Iterable, Iterator, Sized
|
|
from io import BufferedIOBase
|
|
from typing import Any
|
|
|
|
from torch.utils.data.datapipes._decorator import functional_datapipe
|
|
from torch.utils.data.datapipes.datapipe import IterDataPipe
|
|
from torch.utils.data.datapipes.utils.common import _deprecation_warning
|
|
from torch.utils.data.datapipes.utils.decoder import (
|
|
basichandlers as decoder_basichandlers,
|
|
Decoder,
|
|
extension_extract_fn,
|
|
imagehandler as decoder_imagehandler,
|
|
)
|
|
|
|
|
|
# Public names exported by this module; only the DataPipe class is part of the API.
__all__ = ["RoutedDecoderIterDataPipe"]
|
|
|
|
|
|
@functional_datapipe("routed_decode")
class RoutedDecoderIterDataPipe(IterDataPipe[tuple[str, Any]]):
    r"""
    Decodes binary streams from the input DataPipe and yields ``(pathname, decoded_data)`` tuples.

    (functional name: ``routed_decode``)

    Args:
        datapipe: Iterable datapipe yielding ``(pathname, binary stream)`` tuples
        handlers: Optional user-defined decoder handlers. When none are given, the basic
            handlers plus a ``"torch"`` image handler are used. When several are given,
            their priority follows the order in which they are passed (the first handler
            has the highest priority)
        key_fn: Function the decoder uses to derive a dispatch key from the pathname.
            Defaults to extracting the file extension from the pathname

    Note:
        If ``key_fn`` returns anything other than the file extension, the default
        handlers will not work and custom handlers must be supplied. A custom handler
        could, for example, use a regex to decide whether it is eligible to handle the data.
    """

    def __init__(
        self,
        datapipe: Iterable[tuple[str, BufferedIOBase]],
        *handlers: Callable,
        key_fn: Callable = extension_extract_fn,
    ) -> None:
        super().__init__()
        self.datapipe: Iterable[tuple[str, BufferedIOBase]] = datapipe
        # An empty ``handlers`` tuple is falsy, so the defaults (including the
        # image handler instance) are only built when the caller passed none.
        active_handlers = handlers or (
            decoder_basichandlers,
            decoder_imagehandler("torch"),
        )
        self.decoder = Decoder(*active_handlers, key_fn=key_fn)
        # Report that this DataPipe is deprecated (since 1.12, removal planned in 1.13).
        _deprecation_warning(
            type(self).__name__,
            deprecation_version="1.12",
            removal_version="1.13",
            old_functional_name="routed_decode",
        )

    def add_handler(self, *handler: Callable) -> None:
        """Forward additional handlers to the underlying ``Decoder``."""
        self.decoder.add_handler(*handler)

    def __iter__(self) -> Iterator[tuple[str, Any]]:
        """Yield ``(pathname, decoded_value)`` for every sample of the source datapipe."""
        for sample in self.datapipe:
            key = sample[0]
            decoded = self.decoder(sample)
            # The decoder's result is looked up by pathname (presumably a dict
            # keyed by pathname -- see ``Decoder``).
            yield key, decoded[key]

    def __len__(self) -> int:
        """Return the source datapipe's length; raise ``TypeError`` if it is not ``Sized``."""
        source = self.datapipe
        if not isinstance(source, Sized):
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return len(source)
|