mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This is the result of applying the ruff `UP035` check. `Callable` is imported from `collections.abc` instead of `typing`. `TypeAlias` is also imported from `typing`. This PR is the follow-up of #163947. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164054 Approved by: https://github.com/ezyang, https://github.com/Skylion007
400 lines
16 KiB
Python
400 lines
16 KiB
Python
# mypy: allow-untyped-defs
|
|
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
|
|
|
|
These methods are used to collate samples fetched from dataset into Tensor(s).
|
|
These **need** to be in global scope since Py2 doesn't support serializing
|
|
static methods.
|
|
|
|
`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
|
|
"""
|
|
|
|
import collections
|
|
import contextlib
|
|
import copy
|
|
import re
|
|
from collections.abc import Callable
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
|
|
|
|
# Searched against a NumPy array's ``dtype.str`` to detect arrays of string
# classes and objects; `default_convert` leaves such arrays unconverted and
# `collate_numpy_array_fn` rejects them with a ``TypeError``.
np_str_obj_array_pattern = re.compile(r"[SaUO]")
|
|
|
|
|
|
def default_convert(data):
    r"""
    Convert each NumPy array element into a :class:`torch.Tensor`.

    If the input is a `Sequence`, `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
    If the input is not a NumPy array, it is left unchanged.
    This is used as the default function for collation when both `batch_sampler` and `batch_size`
    are NOT defined in :class:`~torch.utils.data.DataLoader`.

    The general input type to output type mapping is similar to that
    of :func:`~torch.utils.data.default_collate`. See the description there for more details.

    Args:
        data: a single data point to be converted

    Returns:
        The converted data point. Tensors, strings, bytes (including NumPy
        ``str_`` / ``bytes_`` scalars), NumPy string/object arrays and any
        unrecognized type are returned unchanged. Mappings and sequences are
        rebuilt with the same container type when possible, otherwise as a
        plain ``dict`` / ``list``. Plain tuples become lists.

    Examples:
        >>> # xdoctest: +SKIP
        >>> # Example with `int`
        >>> default_convert(0)
        0
        >>> # Example with NumPy array
        >>> default_convert(np.array([0, 1]))
        tensor([0, 1])
        >>> # Example with NamedTuple
        >>> Point = namedtuple("Point", ["x", "y"])
        >>> default_convert(Point(0, 0))
        Point(x=0, y=0)
        >>> default_convert(Point(np.array(0), np.array(0)))
        Point(x=tensor(0), y=tensor(0))
        >>> # Example with List
        >>> default_convert([np.array([0, 1]), np.array([2, 3])])
        [tensor([0, 1]), tensor([2, 3])]
    """
    elem_type = type(data)
    if isinstance(data, torch.Tensor):
        return data
    elif elem_type.__module__ == "numpy" and elem_type.__name__ not in (
        "str_",
        "string_",  # legacy alias name, kept for backwards compatibility
        "bytes_",  # actual type name of NumPy byte-string scalars
    ):
        # array of string classes and object
        if (
            elem_type.__name__ == "ndarray"
            and np_str_obj_array_pattern.search(data.dtype.str) is not None
        ):
            return data
        return torch.as_tensor(data)
    elif isinstance(data, collections.abc.Mapping):
        # Convert the values *outside* the ``try`` so that a ``TypeError``
        # raised while converting a nested element propagates directly instead
        # of triggering a second, redundant conversion of the whole subtree.
        converted_map = {key: default_convert(data[key]) for key in data}
        try:
            if isinstance(data, collections.abc.MutableMapping):
                # The mapping type may have extra properties, so we can't just
                # use `type(data)(...)` to create the new mapping.
                # Create a clone and update it if the mapping type is mutable.
                clone = copy.copy(data)
                clone.update(converted_map)
                return clone
            return elem_type(converted_map)
        except TypeError:
            # The mapping type may not support `copy()` / `update(mapping)`
            # or `__init__(iterable)`.
            return converted_map
    elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
        return elem_type(*(default_convert(d) for d in data))
    elif isinstance(data, tuple):
        return [default_convert(d) for d in data]  # Backwards compatibility.
    elif isinstance(data, collections.abc.Sequence) and not isinstance(
        data, (str, bytes)
    ):
        # Same reasoning as for mappings: only container construction is
        # guarded by the ``try`` below.
        converted_seq = [default_convert(d) for d in data]
        try:
            if isinstance(data, collections.abc.MutableSequence):
                # The sequence type may have extra properties, so we can't just
                # use `type(data)(...)` to create the new sequence.
                # Create a clone and update it if the sequence type is mutable.
                clone = copy.copy(data)  # type: ignore[arg-type]
                for i, value in enumerate(converted_seq):
                    clone[i] = value
                return clone
            return elem_type(converted_seq)
        except TypeError:
            # The sequence type may not support `copy()` / `__setitem__(index, item)`
            # or `__init__(iterable)` (e.g., `range`).
            return converted_seq
    else:
        return data
|
|
|
|
|
|
# Message template for the ``TypeError`` raised when a batch element cannot be
# collated; the ``{}`` placeholder receives the offending type (or dtype).
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, "
    "numpy arrays, numbers, dicts or lists; "
    "found {}"
)
|
|
|
|
|
|
def collate(
    batch,
    *,
    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
):
    r"""
    General collate function that handles collection type of element within each batch.

    The function also opens function registry to deal with specific element types. `default_collate_fn_map`
    provides default collate functions for tensors, numpy arrays, numbers and strings.

    Args:
        batch: a single batch to be collated
        collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.
            If the element type isn't present in this dictionary,
            this function will go through each key of the dictionary in the insertion order to
            invoke the corresponding collate function if the element type is a subclass of the key.

    Returns:
        The collated batch. Elements handled by ``collate_fn_map`` return whatever the
        registered function returns. Mappings, namedtuples and sequences are collated
        field by field and rebuilt with the element's container type when possible
        (plain tuples become lists; containers that cannot be rebuilt fall back to a
        plain ``dict`` / ``list``).

    Raises:
        RuntimeError: if the sequence elements of the batch do not all have the same length.
        TypeError: if the element type has no registered collate function and is not a
            mapping, namedtuple or (non-``str``) sequence.

    Examples:
        >>> def collate_tensor_fn(batch, *, collate_fn_map):
        ...     # Extend this function to handle batch of tensors
        ...     return torch.stack(batch, 0)
        >>> def custom_collate(batch):
        ...     collate_map = {torch.Tensor: collate_tensor_fn}
        ...     return collate(batch, collate_fn_map=collate_map)
        >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
        >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

    Note:
        Each collate function requires a positional argument for batch and a keyword argument
        for the dictionary of collate functions as `collate_fn_map`.
    """
    elem = batch[0]
    elem_type = type(elem)

    if collate_fn_map is not None:
        # An exact type match wins over the subclass scan below.
        if elem_type in collate_fn_map:
            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)

        for collate_type in collate_fn_map:
            if isinstance(elem, collate_type):
                return collate_fn_map[collate_type](
                    batch, collate_fn_map=collate_fn_map
                )

    if isinstance(elem, collections.abc.Mapping):
        # Collate the values *outside* the ``try`` so that a ``TypeError`` raised
        # for a nested element propagates directly. Guarding the recursion would
        # re-collate the whole subtree in the fallback (doubling the work at every
        # nesting level) only to raise the same error again.
        collated_map = {
            key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map)
            for key in elem
        }
        try:
            if isinstance(elem, collections.abc.MutableMapping):
                # The mapping type may have extra properties, so we can't just
                # use `type(data)(...)` to create the new mapping.
                # Create a clone and update it if the mapping type is mutable.
                clone = copy.copy(elem)
                clone.update(collated_map)
                return clone
            return elem_type(collated_map)
        except TypeError:
            # The mapping type may not support `copy()` / `update(mapping)`
            # or `__init__(iterable)`.
            return collated_map
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(
            *(
                collate(samples, collate_fn_map=collate_fn_map)
                for samples in zip(*batch)
            )
        )
    elif isinstance(elem, collections.abc.Sequence) and not isinstance(elem, str):
        # `str` is excluded: iterating a string yields one-character strings, so
        # treating it as a generic sequence would recurse without end. Without a
        # registered handler it falls through to the TypeError below.

        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(sample) == elem_size for sample in it):
            raise RuntimeError("each element in list of batch should be of equal size")

        # As above, the recursion happens before (and outside) the ``try``.
        collated_seq = [
            collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)
        ]

        if isinstance(elem, tuple):
            return collated_seq  # Backwards compatibility.

        try:
            if isinstance(elem, collections.abc.MutableSequence):
                # The sequence type may have extra properties, so we can't just
                # use `type(data)(...)` to create the new sequence.
                # Create a clone and update it if the sequence type is mutable.
                clone = copy.copy(elem)  # type: ignore[arg-type]
                for i, value in enumerate(collated_seq):
                    clone[i] = value
                return clone
            return elem_type(collated_seq)
        except TypeError:
            # The sequence type may not support `copy()` / `__setitem__(index, item)`
            # or `__init__(iterable)` (e.g., `range`).
            return collated_seq

    raise TypeError(default_collate_err_msg_format.format(elem_type))
|
|
|
|
|
|
def collate_tensor_fn(
    batch,
    *,
    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
):
    r"""
    Stack a batch of tensors along a new leading dimension.

    Args:
        batch: a sequence of tensors; the first one decides which checks apply
        collate_fn_map: accepted for signature compatibility with the other
            collate functions; not used here

    Raises:
        RuntimeError: if the first tensor is nested or uses a sparse layout.
    """
    first = batch[0]

    if first.is_nested:
        raise RuntimeError(
            "Batches of nested tensors are not currently supported by the default collate_fn; "
            "please provide a custom collate_fn to handle them appropriately."
        )

    sparse_layouts = {
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_bsr,
        torch.sparse_csc,
        torch.sparse_bsc,
    }
    if first.layout in sparse_layouts:
        raise RuntimeError(
            "Batches of sparse tensors are not currently supported by the default collate_fn; "
            "please provide a custom collate_fn to handle them appropriately."
        )

    out = None
    in_worker = torch.utils.data.get_worker_info() is not None
    if in_worker:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        total_elems = sum(t.numel() for t in batch)
        shared = first._typed_storage()._new_shared(total_elems, device=first.device)
        out = first.new(shared).resize_(len(batch), *list(first.size()))
    return torch.stack(batch, 0, out=out)
|
|
|
|
|
|
def collate_numpy_array_fn(
    batch,
    *,
    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
):
    r"""
    Convert a batch of NumPy arrays to tensors and collate the result.

    Args:
        batch: a sequence of NumPy arrays; the dtype of the first one is inspected
        collate_fn_map: registry forwarded to :func:`collate` for the converted tensors

    Raises:
        TypeError: if the first array holds string classes or objects.
    """
    first = batch[0]
    # array of string classes and object
    is_str_or_obj = np_str_obj_array_pattern.search(first.dtype.str) is not None
    if is_str_or_obj:
        raise TypeError(default_collate_err_msg_format.format(first.dtype))

    tensors = [torch.as_tensor(arr) for arr in batch]
    return collate(tensors, collate_fn_map=collate_fn_map)
|
|
|
|
|
|
def collate_numpy_scalar_fn(
    batch,
    *,
    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
):
    r"""
    Turn a batch of NumPy scalars into a single tensor via :func:`torch.as_tensor`.

    ``collate_fn_map`` is accepted for signature compatibility and is not used.
    """
    collated = torch.as_tensor(batch)
    return collated
|
|
|
|
|
|
def collate_float_fn(
    batch,
    *,
    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
):
    r"""
    Build a ``torch.float64`` tensor from a batch of Python floats.

    ``collate_fn_map`` is accepted for signature compatibility and is not used.
    """
    target_dtype = torch.float64
    return torch.tensor(batch, dtype=target_dtype)
|
|
|
|
|
|
def collate_int_fn(
    batch,
    *,
    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
):
    r"""
    Build a tensor from a batch of Python ints using :func:`torch.tensor`.

    ``collate_fn_map`` is accepted for signature compatibility and is not used.
    """
    collated = torch.tensor(batch)
    return collated
|
|
|
|
|
|
def collate_str_fn(
    batch,
    *,
    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
):
    r"""
    Return the batch as-is; string-like elements are not converted.

    ``collate_fn_map`` is accepted for signature compatibility and is not used.
    """
    # The very same object is handed back; no copy is made.
    return batch
|
|
|
|
|
|
# Registry used by `default_collate`. Insertion order matters: when an element's
# exact type is not a key, `collate` scans the keys in this order with
# `isinstance`, so the NumPy entries must precede the builtin `float` / `int`.
default_collate_fn_map: dict[Union[type, tuple[type, ...]], Callable] = {
    torch.Tensor: collate_tensor_fn,
}
with contextlib.suppress(ImportError):
    # NumPy is optional; if it is missing, no NumPy handlers are registered.
    import numpy as np

    default_collate_fn_map.update(
        {
            # For both ndarray and memmap (subclass of ndarray)
            np.ndarray: collate_numpy_array_fn,
            # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html
            # Skip string scalars
            (np.bool_, np.number, np.object_): collate_numpy_scalar_fn,
        }
    )
default_collate_fn_map.update(
    {
        float: collate_float_fn,
        int: collate_int_fn,
        str: collate_str_fn,
        bytes: collate_str_fn,
    }
)
|
|
|
|
|
|
def default_collate(batch):
    r"""
    Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.

    The exact output type can be a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
    Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
    This is used as the default function for collation when
    `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.

    Here is the general input type (based on the type of the element within the batch) to output type mapping:

        * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
        * NumPy Arrays -> :class:`torch.Tensor`
        * `float` -> :class:`torch.Tensor`
        * `int` -> :class:`torch.Tensor`
        * `str` -> `str` (unchanged)
        * `bytes` -> `bytes` (unchanged)
        * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
        * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
          default_collate([V2_1, V2_2, ...]), ...]`
        * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
          default_collate([V2_1, V2_2, ...]), ...]`

    Args:
        batch: a single batch to be collated

    Returns:
        The result of :func:`collate` called with ``default_collate_fn_map`` as the registry.

    Examples:
        >>> # xdoctest: +SKIP
        >>> # Example with a batch of `int`s:
        >>> default_collate([0, 1, 2, 3])
        tensor([0, 1, 2, 3])
        >>> # Example with a batch of `str`s:
        >>> default_collate(["a", "b", "c"])
        ['a', 'b', 'c']
        >>> # Example with `Map` inside the batch:
        >>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
        {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
        >>> # Example with `NamedTuple` inside the batch:
        >>> Point = namedtuple("Point", ["x", "y"])
        >>> default_collate([Point(0, 0), Point(1, 1)])
        Point(x=tensor([0, 1]), y=tensor([0, 1]))
        >>> # Example with `Tuple` inside the batch:
        >>> default_collate([(0, 1), (2, 3)])
        [tensor([0, 2]), tensor([1, 3])]
        >>> # Example with `List` inside the batch:
        >>> default_collate([[0, 1], [2, 3]])
        [tensor([0, 2]), tensor([1, 3])]
        >>> # Two options to extend `default_collate` to handle specific type
        >>> # Option 1: Write custom collate function and invoke `default_collate`
        >>> def custom_collate(batch):
        ...     elem = batch[0]
        ...     if isinstance(elem, CustomType):  # Some custom condition
        ...         return ...
        ...     else:  # Fall back to `default_collate`
        ...         return default_collate(batch)
        >>> # Option 2: In-place modify `default_collate_fn_map`
        >>> def collate_customtype_fn(batch, *, collate_fn_map=None):
        ...     return ...
        >>> default_collate_fn_map.update({CustomType: collate_customtype_fn})
        >>> default_collate(batch)  # Handle `CustomType` automatically
    """
    # All of the work is delegated to `collate` with the module-level registry.
    return collate(batch, collate_fn_map=default_collate_fn_map)
|