Summary: Resubmit of #20698, which got messed up. The idea is that when PyTorch is used in a custom build environment (e.g., Facebook), it is useful to track usage of various APIs centrally. This PR introduces a simple, very lightweight mechanism to do so: only the first invocation of a trigger point is logged. This is significantly more lightweight than #18235, so we can afford to put logging in, e.g., TensorImpl. It also adds an initial list of trigger points. Trigger points are added in such a way that no static initialization triggers them, i.e., merely linking with libtorch.so will not cause any logging. Further suggestions of what to log are welcome.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20745
Differential Revision: D15429196
Pulled By: dzhulgakov
fbshipit-source-id: a5e41a709a65b7ebccc6b95f93854e583cf20aca
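For orientation, the Python-side trigger point added in this file is the single call in `DataLoader.__init__` below. A minimal sketch of hitting such a trigger point (the key string is an arbitrary label, and the call is assumed to be a no-op unless the build registers a central logging backend):

    import torch

    # Logged at most once per process; subsequent invocations are ignored.
    torch._C._log_api_usage_once("python.data_loader")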
679 lines · 32 KiB · Python
r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes.
|
|
|
|
To support these two classes, in `./_utils` we define many utility methods and
|
|
functions to be run in multiprocessing. E.g., the data loading worker loop is
|
|
in `./_utils/worker.py`.
|
|
"""
|
|
|
|
import torch
|
|
import torch.multiprocessing as multiprocessing
|
|
from . import SequentialSampler, RandomSampler, BatchSampler
|
|
from . import _utils
|
|
import threading
|
|
from torch._six import queue


# This function used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate = _utils.collate.default_collate
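# For instance, legacy user code like the following keeps working through this
# alias (a BC sketch, not a recommended import path):
#
#     from torch.utils.data.dataloader import default_collate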


class DataLoader(object):
r"""
|
|
Data loader. Combines a dataset and a sampler, and provides
|
|
single- or multi-process iterators over the dataset.
|
|
|
|
Arguments:
|
|
dataset (Dataset): dataset from which to load the data.
|
|
batch_size (int, optional): how many samples per batch to load
|
|
(default: ``1``).
|
|
shuffle (bool, optional): set to ``True`` to have the data reshuffled
|
|
at every epoch (default: ``False``).
|
|
sampler (Sampler, optional): defines the strategy to draw samples from
|
|
the dataset. If specified, ``shuffle`` must be False.
|
|
batch_sampler (Sampler, optional): like sampler, but returns a batch of
|
|
indices at a time. Mutually exclusive with :attr:`batch_size`,
|
|
:attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
|
|
num_workers (int, optional): how many subprocesses to use for data
|
|
loading. 0 means that the data will be loaded in the main process.
|
|
(default: ``0``)
|
|
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
|
|
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
|
|
into CUDA pinned memory before returning them. If your data elements
|
|
are a custom type, or your ``collate_fn`` returns a batch that is a custom type
|
|
see the example below.
|
|
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
|
if the dataset size is not divisible by the batch size. If ``False`` and
|
|
the size of dataset is not divisible by the batch size, then the last batch
|
|
will be smaller. (default: ``False``)
|
|
timeout (numeric, optional): if positive, the timeout value for collecting a batch
|
|
from workers. Should always be non-negative. (default: ``0``)
|
|
worker_init_fn (callable, optional): If not ``None``, this will be called on each
|
|
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
|
input, after seeding and before data loading. (default: ``None``)
|
|

    .. note:: When ``num_workers != 0``, the corresponding worker processes are
              created each time an iterator for the DataLoader is obtained (as
              when you call ``enumerate(dataloader, 0)``).
              At this point, the dataset, ``collate_fn`` and ``worker_init_fn``
              are passed to each worker, where they are used to access and
              initialize data based on the indices queued up from the main
              process. This means that dataset access together with its
              internal IO, transforms and collation runs in the worker, while
              any shuffle randomization is done in the main process, which
              guides loading by assigning indices to load. Workers are shut
              down once the end of the iteration is reached.

              Since workers rely on Python multiprocessing, worker launch
              behavior is different on Windows compared to Unix. On Unix,
              fork() is used as the default multiprocessing start method, so
              child workers typically can access the dataset and Python
              argument functions directly through the cloned address space. On
              Windows, another interpreter is launched which runs your main
              script, followed by the internal worker function that receives
              the dataset, collate_fn and other arguments through pickle
              serialization.

              This separate serialization means that you should take two steps
              to ensure you are compatible with Windows while using workers
              (this also works equally well on Unix):

              - Wrap most of your main script's code within an
                ``if __name__ == '__main__':`` block, to make sure it doesn't
                run again (most likely generating errors) when each worker
                process is launched. You can place your dataset and DataLoader
                instance creation logic here, as it doesn't need to be
                re-executed in workers.

              - Make sure that ``collate_fn``, ``worker_init_fn`` or any custom
                dataset code is declared as a top-level ``def``, outside of
                that ``__main__`` check. This ensures they are available in
                workers as well (this is needed since functions are pickled as
                references only, not bytecode); see the sketch below.
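
              A minimal sketch of such a script layout (the toy dataset and
              ``my_collate`` here are illustrative placeholders, not part of
              this API)::

                  def my_collate(batch):
                      # top-level ``def``, so workers can unpickle it by reference
                      return default_collate(batch)

                  if __name__ == '__main__':
                      inps = torch.arange(10, dtype=torch.float32).view(5, 2)
                      dataset = TensorDataset(inps, inps)
                      loader = DataLoader(dataset, num_workers=2,
                                          collate_fn=my_collate)
                      for batch in loader:
                          pass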

              By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long
              generated by the main process using its RNG. However, seeds for
              other libraries may be duplicated upon initializing workers
              (e.g., NumPy), causing each worker to return identical random
              numbers. (See :ref:`dataloader-workers-random-seed` section in
              FAQ.) You may use :func:`torch.initial_seed()` to access the
              PyTorch seed for each worker in :attr:`worker_init_fn`, and use
              it to set other seeds before data loading.
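
              For example, a minimal sketch of such a ``worker_init_fn``
              (assuming NumPy is the other library whose per-worker seed needs
              de-duplication)::

                  def worker_init_fn(worker_id):
                      import numpy as np
                      # derive a distinct per-worker NumPy seed from the
                      # per-worker PyTorch seed (NumPy seeds must be < 2**32)
                      np.random.seed(torch.initial_seed() % (2 ** 32))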

    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` cannot
                 be an unpicklable object, e.g., a lambda function.

    The default memory pinning logic only recognizes Tensors and maps and iterables
    containing Tensors. By default, if the pinning logic sees a batch that is a custom type
    (which will occur if you have a ``collate_fn`` that returns a custom batch type),
    or if each element of your batch is a custom type, the pinning logic will not
    recognize them, and it will return that batch (or those elements)
    without pinning the memory. To enable memory pinning for custom batch or data types,
    define a ``pin_memory`` method on your custom type(s).

    Example::

        class SimpleCustomBatch:
            def __init__(self, data):
                transposed_data = list(zip(*data))
                self.inp = torch.stack(transposed_data[0], 0)
                self.tgt = torch.stack(transposed_data[1], 0)

            def pin_memory(self):
                self.inp = self.inp.pin_memory()
                self.tgt = self.tgt.pin_memory()
                return self

        def collate_wrapper(batch):
            return SimpleCustomBatch(batch)

        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        dataset = TensorDataset(inps, tgts)

        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                            pin_memory=True)

        for batch_ndx, sample in enumerate(loader):
            print(sample.inp.is_pinned())
            print(sample.tgt.is_pinned())
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        torch._C._log_api_usage_once("python.data_loader")
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)


class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    #
    # Preliminary:
    #
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                          ||
    #                     |                                ||
    #               {index_queue}                          ||
    #                     |                                ||
    #              worker processes                        ||     DATA
    #                     |                                ||
    #            {worker_result_queue}                     ||     FLOW
    #                     |                                ||
    #     pin_memory_thread of main process                ||  DIRECTION
    #                     |                                ||
    #                {data_queue}                          ||
    #                     |                                ||
    #                data output                           \/
    #
    # P.S. the `worker_result_queue` and `pin_memory_thread` parts may be
    #      omitted if `pin_memory=False`.
    #
    #
    # Terminating multiprocessing logic requires very careful design. In
    # particular, we need to make sure that
    #
    #   1. The iterator gracefully exits the workers when its last reference is
    #      gone or it is depleted.
    #
    #      In this case, the workers should be gracefully exited because the
    #      main process may still need to continue to run, and we want cleaning
    #      up code in the workers to be executed (e.g., releasing GPU memory).
    #      Naturally, we implement the shutdown logic in `__del__` of
    #      `_DataLoaderIter`.
    #
    #      We delay the discussion on the logic in this case until later.
    #
    #   2. The iterator exits the workers when the loader process and/or worker
    #      processes exit normally or with error.
    #
    #      We set all workers and `pin_memory_thread` to have `daemon=True`.
    #
    #      You may ask, why can't we make the workers non-daemonic, and
    #      gracefully exit using the same logic as we have in `__del__` when the
    #      iterator gets deleted (see 1 above)?
    #
    #      First of all, `__del__` is **not** guaranteed to be called when the
    #      interpreter exits. Even if it is called, by the time it executes,
    #      many Python core library resources may already be freed, and even
    #      simple things like acquiring an internal lock of a queue may hang.
    #      Therefore, in this case, we actually need to prevent `__del__` from
    #      being executed, and rely on the automatic termination of daemonic
    #      children. Thus, we register an `atexit` hook that sets a global flag
    #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
    #      reverse order of registration, we are guaranteed that this flag is
    #      set before library resources we use are freed. (Hooks freeing those
    #      resources are registered at importing the Python core libraries at
    #      the top of this file.) So in `__del__`, we check if
    #      `_utils.python_exit_status` is set or `None` (freed), and perform a
    #      no-op if so.
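    #
    #      (A minimal sketch of that hook, as registered in `./_utils`; shown
    #      here for orientation only)::
    #
    #          import atexit
    #
    #          python_exit_status = False
    #
    #          def _set_python_exit_flag():
    #              global python_exit_status
    #              python_exit_status = True
    #
    #          atexit.register(_set_python_exit_flag)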
    #
    #      Another problem with `__del__` is also related to the library cleanup
    #      calls. When a process ends, it shuts all its daemonic children
    #      down with a SIGTERM (instead of joining them without a timeout).
    #      Similarly for threads, but by a different mechanism. This fact,
    #      together with a few implementation details of multiprocessing, forces
    #      us to make workers daemonic. All of our problems arise when a
    #      DataLoader is used in a subprocess, and are caused by multiprocessing
    #      code which looks more or less like this:
    #
    #          try:
    #              your_function_using_a_dataloader()
    #          finally:
    #              multiprocessing.util._exit_function()
    #
    #      The joining/termination mentioned above happens inside
    #      `_exit_function()`. Now, if `your_function_using_a_dataloader()`
    #      throws, the stack trace stored in the exception will prevent the
    #      frame which uses `_DataLoaderIter` from being freed. If the frame
    #      has any reference to the `_DataLoaderIter` (e.g., in a method of the
    #      iter), its `__del__`, which starts the shutdown procedure, will not
    #      be called. That, in turn, means that workers aren't notified.
    #      Attempting to join in `_exit_function` will then result in a hang.
    #
    #      For context, `_exit_function` is also registered as an `atexit` call.
    #      So it is unclear to me (@ssnl) why this is needed in a finally block.
    #      The code dates back to 2008 and there is no comment on the original
    #      PEP 371 or patch https://bugs.python.org/issue3050 (containing both
    #      the finally block and the `atexit` registration) that explains this.
    #
    #      Another choice is to just shutdown workers with the logic in 1 above
    #      whenever we see an error in `next`. This isn't ideal because
    #        a. It prevents users from using try-catch to resume data loading.
    #        b. It doesn't prevent hanging if users have references to the
    #           iterator.
    #
    #   3. All processes exit if any of them die unexpectedly by fatal signals.
    #
    #      As shown above, the workers are set as daemonic children of the main
    #      process. However, automatic cleaning-up of such child processes only
    #      happens if the parent process exits gracefully (e.g., not via fatal
    #      signals like SIGKILL). So we must ensure that each process will exit
    #      even if the process that should send/receive data to/from it was
    #      killed, i.e.,
    #
    #      a. A process won't hang when getting from a queue.
    #
    #         Even with carefully designed data dependencies (i.e., a `put()`
    #         always corresponding to a `get()`), hanging on `get()` can still
    #         happen when data in the queue is corrupted (e.g., due to
    #         `cancel_join_thread` or unexpected exit).
    #
    #         For child exit, we set a timeout whenever we try to get data
    #         from `data_queue`, and check the workers' status on each timeout
    #         and error.
    #         See `_DataLoaderIter._get_batch()` and
    #         `_DataLoaderIter._try_get_batch()` for details.
    #
    #         Additionally, for child exit on non-Windows platforms, we also
    #         register a SIGCHLD handler (which is not supported on Windows) on
    #         the main process, which checks if any of the workers fail in the
    #         (Python) handler. This is more efficient and faster in detecting
    #         worker failures, compared to only using the above mechanism.
    #         See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
    #
    #         For `.get()` calls where the sender(s) is not the workers, we
    #         guard them with timeouts, and check the status of the sender
    #         when a timeout happens:
    #           + in the workers, the `_utils.worker.ManagerWatchdog` class
    #             checks the status of the main process.
    #           + if `pin_memory=True`, when getting from `pin_memory_thread`,
    #             we check `pin_memory_thread`'s status periodically until
    #             `.get()` returns or we see that `pin_memory_thread` died.
#
|
|
# b. A process won't hang when putting into a queue;
|
|
#
|
|
# We use `mp.Queue` which has a separate background thread to put
|
|
# objects from an unbounded buffer array. The background thread is
|
|
# daemonic and usually automatically joined when the process
|
|
# exits.
|
|
#
|
|
# However, in case that the receiver has ended abruptly while
|
|
# reading from the pipe, the join will hang forever. Therefore,
|
|
# for both `worker_result_queue` (worker -> main process/pin_memory_thread)
|
|
# and each `index_queue` (main process -> worker), we use
|
|
# `q.cancel_join_thread()` in sender process before any `q.put` to
|
|
# prevent this automatic join.
|
|
#
|
|
# Moreover, having all queues called `cancel_join_thread` makes
|
|
# implementing graceful shutdown logic in `__del__` much easier.
|
|
# It won't need to get from any queue, which would also need to be
|
|
# guarded by periodic status checks.
|
|
#
|
|
# Note that this may leave corrupted data in the queue, but we
|
|
# don't care about the data anyways once we are shutting down.
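    #
    #         A minimal standalone sketch of this failure mode and the fix
    #         (`sender` and the payload here are illustrative, not part of the
    #         loader)::
    #
    #             import multiprocessing as mp
    #
    #             def sender(q):
    #                 q.cancel_join_thread()  # don't block exit on the feeder thread
    #                 q.put(b'x' * (1 << 20))  # large enough to fill the pipe buffer
    #
    #             if __name__ == '__main__':
    #                 q = mp.Queue()
    #                 p = mp.Process(target=sender, args=(q,))
    #                 p.start()
    #                 # nobody ever reads `q`; without cancel_join_thread above,
    #                 # `sender` would hang at exit flushing its queue feeder
    #                 # thread, and this join would block forever.
    #                 p.join()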
    #
    #
    # Now let's get back to 1:
    #    how we gracefully exit the workers when the last reference to the
    #    iterator is gone.
    #
    # To achieve this, we implement the following logic along with the design
    # choices mentioned above:
    #
    # [worker processes]
    #   While loader process is alive:
    #     Get from `index_queue`.
    #       If got a `None`, exit.
    #       If got anything else,
    #          Check `done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until we see the `None`, then exit.
    #            Otherwise, process data.
    #       If timed out,
    #          No matter whether `done_event` is set (we still need to see the
    #          `None`), must continue to next iteration.
    #
    # [pin_memory_thread]
    #   # No need to check main thread. If this thread is alive, the main loader
    #   # thread must be alive, because this thread is set as daemonic.
    #   While True:
    #     Get from `worker_result_queue`.
    #       If got a `None`, exit.
    #       If got anything else,
    #          Check `done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until we see the `None`, then exit.
    #            Otherwise, process data.
    #
    #   NOTE: we don't check the status of the main thread because
    #     1. if the process is killed by fatal signal, `pin_memory_thread`
    #        ends.
    #     2. in other cases, either the cleaning-up in __del__ or the
    #        automatic exit of daemonic thread will take care of it.
    #        This won't busy-wait either because `.get(timeout)` does not
    #        busy-wait.
    #
    # [main process]
    #   In the _DataLoaderIter's `__del__`
    #
    #   a. Set `done_event` (shared with `pin_memory_thread` and workers).
    #
    #      Note: from here on, the workers & `pin_memory_thread` may exit at
    #            any time after they receive `None`.
    #
    #   b. Exit `pin_memory_thread`
    #        i.   Put `None` in `worker_result_queue`.
    #        ii.  Join the `pin_memory_thread`.
    #
    #   c. Exit the workers.
    #        i.   Put `None` in each worker's `index_queue`.
    #        ii.  Join the workers.
    #
    #      NOTE: This has to be after (b) because it may leave corrupted data
    #            in `worker_result_queue`, which `pin_memory_thread` reads
    #            from.
    #
    #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
    #         can be omitted.
    #
    # NB: `done_event` isn't strictly needed. E.g., we can just check for
    #     `None` from `index_queue`, but it allows us to skip wasting resources
    #     processing indices already in `index_queue` if we are already shutting
    #     down.

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()
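        # NB: each worker's PyTorch RNG will be seeded with `base_seed + i`
        # below, matching the seeding behavior described in the DataLoader
        # docstring.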

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_utils.worker._worker_loop,
                    args=(self.dataset, index_queue,
                          self.worker_result_queue, self.done_event,
                          self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                w.daemon = True
                # NB: Process.start() actually takes some time as it needs to
                # start a process and pass the arguments over via a pipe.
                # Therefore, we only add a worker to the self.workers list
                # after it has started, so that we do not call .join() if the
                # program dies before it starts, and __del__ tries to join but
                # will get:
                #     AssertionError: can only join a started process.
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(self.worker_result_queue, self.data_queue,
                          torch.cuda.current_device(), self.done_event))
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                # Similar to workers (see comment above), we only register
                # pin_memory_thread once it is started.
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
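            # (`_put_indices` assigns each batch of indices to workers
            # round-robin, and `_process_next_batch` tops the pipeline back up
            # as batches are consumed, so up to `2 * num_workers` batches stay
            # in flight)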
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _try_get_batch(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `data_queue` for a given timeout. This can
        # also be used as the inner loop of fetching without timeout, with the
        # sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died unexpectedly. This
        # error can come from either the SIGCHLD handler in
        # `_utils/signal_handling.py` (only for non-Windows platforms), or the
        # manual check below on errors and timeouts.
        #
        # Returns a 2-tuple:
        #     (bool: whether we successfully got data, any: data if successful else None)
        try:
            data = self.data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # At timeout and error, we manually check whether any worker has
            # failed. Note that this is the only mechanism for Windows to
            # detect worker failures.
            if not all(w.is_alive() for w in self.workers):
                pids_str = ', '.join(str(w.pid) for w in self.workers if not w.is_alive())
                raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
            if isinstance(e, queue.Empty):
                return (False, None)
            raise

    def _get_batch(self):
        # Fetches data from `self.data_queue`.
        #
        # We check the workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
        # which we achieve by running `self._try_get_batch(timeout=MP_STATUS_CHECK_INTERVAL)`
        # in a loop. This is the only mechanism to detect worker failures for
        # Windows. For other platforms, a SIGCHLD handler is also used for
        # worker failure detection.
        #
        # If `pin_memory=True`, we also need to check whether `pin_memory_thread`
        # has died at timeouts.
        if self.timeout > 0:
            success, data = self._try_get_batch(self.timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        elif self.pin_memory:
            while self.pin_memory_thread.is_alive():
                success, data = self._try_get_batch()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self.data_queue` is a `queue.Queue`. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
                success, data = self._try_get_batch()
                if success:
                    return data

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration
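
        # Batches can arrive from workers out of order; keep fetching and
        # stashing them in `reorder_dict` until the batch with index
        # `rcvd_idx` (the next one in sampler order) shows up.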
        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self
def _put_indices(self):
|
|
assert self.batches_outstanding < 2 * self.num_workers
|
|
indices = next(self.sample_iter, None)
|
|
if indices is None:
|
|
return
|
|
self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
|
|
self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
|
|
self.batches_outstanding += 1
|
|
self.send_idx += 1
|

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
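        # replenish the pipeline: queue one new batch of indices to replace
        # the batch we are handing back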
        self._put_indices()
        if isinstance(batch, _utils.ExceptionWrapper):
            # make multiline KeyError msg readable by working around
            # a Python bug https://bugs.python.org/issue2651
            if batch.exc_type == KeyError and "\n" in batch.exc_msg:
                raise Exception("KeyError:" + batch.exc_msg)
            else:
                raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("_DataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
        # the logic of this function.
        python_exit_status = _utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            # See (2) of the note. If Python is shutting down, do no-op.
            return
        # Normal exit when last reference is gone / iterator is depleted.
        # See (1) and the second half of the note.
        if not self.shutdown:
            self.shutdown = True
            try:
                self.done_event.set()

                # Exit `pin_memory_thread` first because exiting workers may leave
                # corrupted data in `worker_result_queue` which `pin_memory_thread`
                # reads from.
                if hasattr(self, 'pin_memory_thread'):
                    # Use hasattr in case an error happens before we set the
                    # attribute.

                    # This is the first `worker_result_queue.put` in this
                    # process, so call `cancel_join_thread` first, in case
                    # `pin_memory_thread` already exited.
                    self.worker_result_queue.cancel_join_thread()
                    self.worker_result_queue.put(None)
                    self.pin_memory_thread.join()
                    # Indicate that no more data will be put on this queue by the
                    # current process. This **must** be called after
                    # `pin_memory_thread` is joined because that thread shares the
                    # same pipe handles with this loader thread. If the handle is
                    # closed, Py3 will error in this case, but Py2 will just time
                    # out even if there is data in the queue.
                    self.worker_result_queue.close()

                # Exit workers now.
                for q in self.index_queues:
                    q.put(None)
                    # Indicate that no more data will be put on this queue by
                    # the current process.
                    q.close()
                for w in self.workers:
                    w.join()
            finally:
                # Even though all this function does is putting into queues that
                # we have called `cancel_join_thread` on, weird things can
                # happen when a worker is killed by a signal, e.g., hanging in
                # `Event.set()`. So we need to guard this with the SIGCHLD
                # handler, and remove pids from the C side data structure only
                # at the end.
                #
                # FIXME: Unfortunately, for Windows, we are missing a worker
                #        error detection mechanism here in this function, as it
                #        doesn't provide a SIGCHLD handler.
                if self.worker_pids_set:
                    _utils.signal_handling._remove_worker_pids(id(self))
                    self.worker_pids_set = False

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()