Prevent hanging in data loader altogether

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11985

Differential Revision: D10202374

Pulled By: SsnL

fbshipit-source-id: 1ab1a07185f78a104f9b05930a87ef5a32f431e4
This commit is contained in:
Tongzhou Wang
2018-10-09 09:51:42 -07:00
committed by Facebook Github Bot
parent 1a0d82e4f4
commit 11c31aef04
4 changed files with 522 additions and 140 deletions

View File

@ -91,7 +91,8 @@ TEST_MKL = torch.backends.mkl.is_available()
# TODO: allow Py2 when librosa 0.6.2 releases
TEST_LIBROSA = _check_module_exists('librosa') and PY3
NO_MULTIPROCESSING_SPAWN = os.environ.get('NO_MULTIPROCESSING_SPAWN', '0') == '1'
# Python 2.7 doesn't have spawn
NO_MULTIPROCESSING_SPAWN = os.environ.get('NO_MULTIPROCESSING_SPAWN', '0') == '1' or sys.version_info[0] == 2
TEST_WITH_ASAN = os.getenv('PYTORCH_TEST_WITH_ASAN', '0') == '1'
TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1'
TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1'

View File

@ -5,25 +5,25 @@ import os
import ctypes
import signal
import torch
import gc
import time
import traceback
import unittest
import subprocess
import itertools
from torch import multiprocessing as mp
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data.dataset import random_split
from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MANAGER_STATUS_CHECK_INTERVAL
from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm
# We cannot import TEST_CUDA from common_nn here, because if we do that,
# the TEST_CUDNN line from common_nn will be executed multiple times
# We cannot import TEST_CUDA from common_cuda here, because if we do that,
# the TEST_CUDNN line from common_cuda will be executed multiple times
# as well during the execution of this test suite, and it will cause
# CUDA OOM error on Windows.
TEST_CUDA = torch.cuda.is_available()
# We need spawn start method for test_manager_unclean_exit, but
# Python 2.7 doesn't allow it.
if sys.version_info[0] == 3:
if not NO_MULTIPROCESSING_SPAWN:
# Get a multiprocessing context because some test / third party library will
# set start_method when imported, and setting again triggers RuntimeError.
mp = mp.get_context(method='spawn')
@ -149,15 +149,12 @@ class ErrorTrackingProcess(mp.Process):
self._exception = None
def run(self):
# Disable stderr printing from os level, and make workers not printing
# to stderr.
# Can't use sys.stderr.close, otherwise Python `raise` will error with
# ValueError: I/O operation on closed file.
os.close(sys.stderr.fileno())
# Disable polluting stderr with errors that are supposed to happen.
sys.stderr = open(os.devnull, "w")
try:
super(ErrorTrackingProcess, self).run()
self._cconn.send(None)
except Exception as e:
except Exception:
self._cconn.send(ExceptionWrapper(sys.exc_info()))
raise
@ -259,12 +256,94 @@ def _test_timeout():
_ = next(iter(dataloader))
def _test_timeout_pin_memory():
dataset = SleepDataset(10, 3)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1, pin_memory=True)
_ = next(iter(dataloader))
def disable_stderr(worker_id):
r"""
Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
from workers. Since worker signal handler prints with low-level write(),
this has to be done on OS level via dup.
This is used as worker_init_fn for test_segfault.
"""
sys.stderr.flush() # flush library buffers that dup2 knows nothing about
# Can't use a with-block because otherwise the fd will be closed when this
# function ends.
devnull = open(os.devnull, 'w')
os.dup2(devnull.fileno(), sys.stderr.fileno())
def _test_segfault():
dataset = SegfaultDataset(10)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, worker_init_fn=disable_stderr)
_ = next(iter(dataloader))
class TestProperExitDataset(object):
def __init__(self, size, error_event):
self.size = size
self.error_event = error_event
def __len__(self):
return self.size
def __getitem__(self, idx):
if self.error_event is not None and self.error_event.is_set():
raise RuntimeError('Worker error')
return torch.tensor([idx])
# See TestDataLoader.test_proper_exit for usage
def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
worker_pids, setup_event):
num_workers = 2 if use_workers else 0
if exit_method == 'worker_error' or exit_method == 'worker_kill':
assert use_workers is True
ds = TestProperExitDataset(16, setup_event if exit_method == 'worker_error' else None)
loader = DataLoader(ds, batch_size=2, shuffle=False,
num_workers=num_workers, pin_memory=pin_memory)
it = iter(loader)
if use_workers:
for i, w in enumerate(it.workers):
worker_pids[i] = w.pid
error_it = 4
assert len(loader) > error_it
def kill_pid(pid):
if IS_WINDOWS:
os.system('taskkill /PID ' + str(os.getpid()) + ' /F')
else:
os.kill(os.getpid(), signal.SIGKILL)
for i, _ in enumerate(it):
if i == 0:
if not hold_iter_reference:
del it
setup_event.set()
if i == error_it:
if exit_method == 'main_error':
raise RuntimeError('Error')
elif exit_method == 'main_kill':
kill_pid(os.getpid())
elif exit_method == 'worker_kill':
kill_pid(worker_pids[0])
if not hold_iter_reference:
# Tries to trigger the __del__ clean-up rather than the automatic
# exiting of daemonic children. Technically it should be automatically
# triggered, but I don't want to rely on the implementation detail of
# Python gc.
gc.collect()
# test custom init function
def init_fn(worker_id):
torch.manual_seed(12345)
@ -373,7 +452,12 @@ class TestDataLoader(TestCase):
@skipIfRocm
def test_timeout(self):
p = ErrorTrackingProcess(target=_test_timeout)
if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
targets = (_test_timeout, _test_timeout_pin_memory)
else:
targets = (_test_timeout,)
for target in targets:
p = ErrorTrackingProcess(target=target)
p.start()
p.join(JOIN_TIMEOUT)
try:
@ -506,10 +590,14 @@ class TestDataLoader(TestCase):
self._test_error(DataLoader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4))
@unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_partial_workers(self):
r"""Check that workers exit even if the iterator is not exhausted."""
for pin_memory in (True, False):
if TEST_CUDA:
pin_memory_configs = (True, False)
else:
pin_memory_configs = (False,)
for pin_memory in pin_memory_configs:
loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory))
workers = loader.workers
if pin_memory:
@ -517,6 +605,7 @@ class TestDataLoader(TestCase):
for i, sample in enumerate(loader):
if i == 10:
break
assert i == 10
del loader
for w in workers:
w.join(JOIN_TIMEOUT)
@ -525,24 +614,6 @@ class TestDataLoader(TestCase):
pin_memory_thread.join(JOIN_TIMEOUT)
self.assertFalse(pin_memory_thread.is_alive())
@staticmethod
def _main_process(dataset, worker_pids, main_exit_event, raise_error):
loader = iter(DataLoader(dataset, batch_size=2, num_workers=4, pin_memory=True))
workers = loader.workers
for i in range(len(workers)):
worker_pids[i] = int(workers[i].pid)
for i, sample in enumerate(loader):
if i == 3:
# Simulate an exit of the manager process
main_exit_event.set()
if raise_error:
raise RuntimeError('Error')
else:
if IS_WINDOWS:
os.system('taskkill /PID ' + str(os.getpid()) + ' /F')
else:
os.kill(os.getpid(), signal.SIGKILL)
@staticmethod
def _is_process_alive(pid, pname):
# There is a chance of a terminated child process's pid being reused by a new unrelated process,
@ -558,51 +629,94 @@ class TestDataLoader(TestCase):
output = output.decode('utf-8')
return pname in output
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
@unittest.skipIf(sys.version_info[0] == 2,
"spawn start method is not supported in Python 2, \
but we need it for creating another process with CUDA")
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@skipIfRocm
def test_main_process_unclean_exit(self):
r'''There might be ConnectionResetError or leaked semaphore warning (due to dirty process exit), \
but they are all safe to ignore'''
def test_proper_exit(self):
r'''There might be ConnectionResetError or leaked semaphore warning
(due to dirty process exit), but they are all safe to ignore'''
# `raise_error` controls if the main process is KILL-ed by OS or just
# simply raises an error. Both cases are interesting because
# 1. In case of it is KILL-ed by OS, the workers need to automatically
# discover that their parent is dead and exit gracefully.
# 2. In case of it raises an error itself, the parent process needs to
# take care of exiting the worker and then exits itself gracefully.
for raise_error in (True, False):
worker_pids = mp.Array('i', [0] * 4)
# TODO: test the case where the pin_memory_thread triggers an
# error/fatal signal. I haven't found out how to properly do that.
main_exit_event = mp.Event()
p = mp.Process(target=TestDataLoader._main_process,
args=(self.dataset, worker_pids, main_exit_event, raise_error))
p.start()
worker_pids[-1] = p.pid
# Array to store the worker pids.
worker_pids = mp.Array('i', [-1 for _ in range(10)])
main_exit_event.wait()
exit_status = [False] * len(worker_pids)
def wait_pids(pids, timeout):
r"""Wait for all process specified in pids to exit in given timeout."""
exit_status = [False for _ in pids]
start_time = time.time()
pname = 'python'
while True:
for i in range(len(worker_pids)):
pid = worker_pids[i]
for i in range(len(pids)):
pid = pids[i]
if not exit_status[i]:
if not TestDataLoader._is_process_alive(pid, pname):
exit_status[i] = True
if all(exit_status):
break
else:
if time.time() - start_time > MANAGER_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT:
self.fail('subprocess not terminated')
time.sleep(1)
p.join(MANAGER_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT - (time.time() - start_time))
self.assertFalse(p.is_alive(), 'main process not terminated')
if time.time() - start_time > timeout:
break
time.sleep(0.5)
return exit_status
for use_workers, pin_memory, hold_iter_reference in itertools.product([True, False], repeat=3):
# `hold_iter_reference` specifies whether we hold a reference to the
# iterator. This is interesting because Python3 error traces holds a
# reference to the frames, which hold references to all the local
# variables including the iterator, and then the iterator dtor may
# not be called before process end. It is important to see that the
# processes still exit in both cases.
if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN):
# Can't use CUDA without spawn
continue
# `exit_method` controls the way the loader process ends.
# - `*_kill` means that `*` is killed by OS.
# - `*_error` means that `*` raises an error.
# - `None` means that no error happens.
# In all cases, all processes should end properly.
if use_workers:
exit_methods = [None, 'main_error', 'main_kill', 'worker_kill', 'worker_error']
else:
exit_methods = [None, 'main_error', 'main_kill']
for exit_method in exit_methods:
# clear pids array first
for i in range(len(worker_pids)):
worker_pids[i] = -1
# Event that the loader process uses to signal testing process
# that various things are setup, including that the worker pids
# are specified in `worker_pids` array.
setup_event = mp.Event()
p = ErrorTrackingProcess(target=_test_proper_exit,
args=(use_workers, pin_memory, exit_method,
hold_iter_reference, worker_pids, setup_event))
p.start()
# Wait for loader process to set everything up, i.e., filling
# worker pids in `worker_pids`.
setup_event.wait(timeout=JOIN_TIMEOUT)
self.assertTrue(setup_event.is_set(), 'loader process setup timed out')
pids = [pid for pid in worker_pids if pid > 0]
try:
exit_status = wait_pids(pids, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
if not all(exit_status):
self.fail('subprocess (pid(s) {}) not terminated'.format(
', '.join(p for p, exited in zip(pids, exit_status) if not exited)))
p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
self.assertFalse(p.is_alive(), 'loader process not terminated')
if exit_method is None:
self.assertEqual(p.exitcode, 0)
else:
self.assertNotEqual(p.exitcode, 0)
finally:
p.terminate()
def test_len(self):
def check_len(dl, expected):

View File

@ -1,15 +1,14 @@
#include "DataLoader.h"
// In cases like DataLoader, if a worker process die due to bus error/segfault
// or just hang, the main process, if implemented with
// multiprocessing.queue.SimpleQueue, will hang waiting for data. This is
// difficult to avoid on PyTorch side as it can be caused by limited shm, or
// other libraries users call in the workers. The following methods is an effort
// to do our best provide some error message to users when such unfortunate
// events happen.
// In cases like DataLoader, if a worker process dies due to bus error/segfault
// or just hang, the main process will hang waiting for data. This is difficult
// to avoid on PyTorch side as it can be caused by limited shm, or other
// libraries users call in the workers. The following methods is an effort to do
// our best to provide some error message to users when such unfortunate events
// happen.
// TODO: The following don't work on Windows. Specifically, sigaction, waitid
// calls ,and SIGCHLD handler. Currently, dummy implementations are provided
// calls, and SIGCHLD handler. Currently, dummy implementations are provided
// for Windows.
#ifndef _WIN32
@ -63,6 +62,7 @@ static inline void setSignalHandler(int signal, void(*handler)(int, siginfo_t *,
SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. "
"This might be caused by insufficient shared memory (shm).\n");
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");
SIGNAL_HANDLER(SIGFPE, handler_SIGFPE, "ERROR: Unexpected floating-point exception encountered in worker.\n");
// When an error happend in DataLoader methods and Python starts to exit, the
// error trace will keep the loader alive, and Python may kill the children
@ -92,6 +92,7 @@ static PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *a
setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr);
setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr);
setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr);
setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@ -130,9 +131,7 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
} else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal
std::ostringstream oss;
oss << "DataLoader worker (pid " << worker_pid << ") is killed "
<< "by signal: " << strsignal(infop.si_status) << ". "
<< "Details are lost due to multiprocessing. Rerunning with "
<< "num_workers=0 may give better error trace.";
<< "by signal: " << strsignal(infop.si_status) << ". ";
// This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again.
pid_set->clear();

View File

@ -26,10 +26,20 @@ else:
import queue
# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding reference to the frame,
# and the frame (which holds reference to all the object in its temporary scope)
# holding reference the traceback.
class ExceptionWrapper(object):
r"""Wraps an exception plus traceback to communicate across threads"""
def __init__(self, exc_info):
# It is important that we don't store exc_info, see
# NOTE [ Python Traceback Reference Cycle Problem ]
self.exc_type = exc_info[0]
self.exc_msg = "".join(traceback.format_exception(*exc_info))
@ -37,7 +47,11 @@ class ExceptionWrapper(object):
_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""
MANAGER_STATUS_CHECK_INTERVAL = 5.0
MP_STATUS_CHECK_INTERVAL = 5.0
r"""Interval (in seconds) to check status of processes to avoid hanging in
multiprocessing data loading. This is mainly used in getting data from
another process, in which case we need to periodically check whether the
sender is alive to prevent hanging."""
if IS_WINDOWS:
# On Windows, the parent ID of the worker process remains unchanged when the manager process
@ -60,19 +74,29 @@ if IS_WINDOWS:
if not self.manager_handle:
raise ctypes.WinError(ctypes.get_last_error())
self.manager_dead = False
def is_alive(self):
if not self.manager_dead:
# Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
return not self.manager_dead
else:
class ManagerWatchdog(object):
def __init__(self):
self.manager_pid = os.getppid()
self.manager_dead = False
def is_alive(self):
return os.getppid() == self.manager_pid
if not self.manager_dead:
self.manager_dead = os.getppid() != self.manager_pid
return not self.manager_dead
def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
try:
global _use_shared_memory
_use_shared_memory = True
@ -87,9 +111,6 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
random.seed(seed)
torch.manual_seed(seed)
# Do not wait for putting thread to join when this worker exits.
# Otherwise, this worker may always be waiting to put and doesn't check
# index_queue and done_event for termination signal.
data_queue.cancel_join_thread()
if init_fn is not None:
@ -97,22 +118,26 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
watchdog = ManagerWatchdog()
while True:
while watchdog.is_alive():
try:
r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
if watchdog.is_alive() and not done_event.is_set():
continue
else:
break
# use done_event so that we can get faster exiting signal even if there
# are still indices in index_queue
if r is None or done_event.is_set():
break
if r is None:
# Received the final signal
assert done_event.is_set()
return
elif done_event.is_set():
# Done event is set. But I haven't received the final signal
# (None) yet. I will keep continuing until get it, and skip the
# processing steps.
continue
idx, batch_indices = r
try:
samples = collate_fn([dataset[i] for i in batch_indices])
except Exception:
# It is important that we don't store exc_info in a variable,
# see NOTE [ Python Traceback Reference Cycle Problem ]
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
data_queue.put((idx, samples))
@ -122,25 +147,33 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
pass
def _pin_memory_loop(in_queue, out_queue, done_event, pin_memory, device_id):
if pin_memory:
def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
torch.cuda.set_device(device_id)
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
while True:
try:
r = in_queue.get()
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
except Exception:
if done_event.is_set():
return
raise
if r is None or done_event.is_set():
# Weird things can happen when shutting down, e.g., fd being
# closed when tensors are shared via fds.
break
if isinstance(r[1], ExceptionWrapper):
out_queue.put(r)
raise
if r is None:
assert done_event.is_set()
return
elif done_event.is_set():
# Haven't seen the final signal yet. Keep getting until None.
continue
elif isinstance(r[1], ExceptionWrapper):
out_queue.put(r)
else:
idx, batch = r
try:
if pin_memory:
batch = pin_memory_batch(batch)
except Exception:
out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
@ -230,6 +263,8 @@ def _set_SIGCHLD_handler():
return
previous_handler = signal.getsignal(signal.SIGCHLD)
if not callable(previous_handler):
# This doesn't catch default handler, but SIGCHLD default handler is a
# no-op.
previous_handler = None
def handler(signum, frame):
@ -246,6 +281,207 @@ def _set_SIGCHLD_handler():
class _DataLoaderIter(object):
r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
# NOTE [ Data Loader Multiprocessing Shutdown Logic ]
#
# Preliminary:
#
# Our data model looks like this (queues are indicated with curly brackets):
#
# main process ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {data_queue} ||
# | ||
# data output \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
# `pin_memory=False`.
#
#
# Terminating multiprocessing logic requires very careful design. In
# particular, we need to make sure that
#
# 1. The iterator gracefully exits the workers when its last reference is
# gone.
#
# In this case, the workers should be gracefully exited because the
# main process may still need to continue to run, and we want cleaning
# up code in the workers to be executed (e.g., releasing GPU memory).
# Naturally, we implement the shutdown logic in `__del__` of
# DataLoaderIterator.
#
# We delay the discussion on the logic in this case until later.
#
# 2. The iterator exits the workers when the loader process and/or worker
# processes exits unexpectedly (e.g., SIGKILL-ed).
#
# We set all workers and `pin_memory_thread` to have `daemon=True`.
#
# You may ask, why can't we make the workers non-daemonic, and
# gracefully exit using the same logic as we have in `__del__` when the
# iterator gets deleted (see 1 above)?
#
# When a process ends, it shuts the all its daemonic children down with
# a SIGTERM (instead of joining them without a timeout). Simiarly for
# threads, but by a different mechanism. This fact, together with a few
# implementation details of multiprocessing, forces us to make workers
# daemonic. All of our problems arise when a DataLoader is used in a
# subprocess, and are caused by multiprocessing code which looks more
# or less like this:
#
# try:
# your_function_using_a_dataloader()
# finally:
# multiprocessing.util._exit_function()
#
# The joining/termination mentioned above happens inside
# `_exit_function()`. Now, if `your_function_using_a_dataloader()`
# throws, the stack trace stored in the exception will prevent the
# frame which uses `DataLoaderIter` to be freed. If the frame has any
# reference to the `DataLoaderIter` (e.g., in a method of the iter),
# its `__del__`, which starts the shutdown procedure, will not be
# called. That, in turn, means that workers aren't notified. Attempting
# to join in `_exit_function` will then result in a hang.
#
# For context, `_exit_function` is also registered as an `atexit` call.
# So it is unclear to me (@ssnl) why this is needed in a finally block.
# The code dates back to 2008 and there is no comment on the original
# PEP 371 or patch https://bugs.python.org/issue3050 (containing both
# the finally block and the `atexit` registration) that explains this.
#
# Another choice is to just shutdown workers with logic in 1 above
# whenever we see an error in `next`. This isn't ideal because
# a. It prevents users from using try-catch to resume data loading.
# b. It doesn't prevent hanging if users have references to the
# iterator.
#
# 3. All processes exit if any of them die unexpectedly (e.g., error,
# fatal signals).
#
# As shown above, the workers are set as daemonic children of the main
# process. However, automatic cleaning-up of such child processes only
# happens if the parent process exits gracefully (e.g., not via fatal
# signals like SIGKILL). So we must ensure that each process will exit
# even the process that should send/receive data to/from it were
# killed, i.e.,
#
# a. A process won't hang when getting from a queue.
#
# Even with carefully designed data dependencies (i.e., a `put()`
# always corresponding to a `get()`), hanging on `get()` can still
# happen when data in queue is corrupted (e.g., due to
# `cancel_join_thread` or unexpected exit).
#
# For child exit, we register SIGCHLD handler on main process,
# which checks if any of the workers fail in the (Python) handler.
# See DataLoader.cpp.
#
# For `.get()` calls where the sender(s) is not the workers, we
# guard them with timeouts, and check the status of the sender
# when timeout happens:
# + in the workers, the `ManagerWatchdog` class checks the main
# process status.
# + if `pin_memory=True`, when getting from `pin_memory_thread`,
# check `pin_memory_thread` status periodically until `.get()`
# returns or see that `pin_memory_thread` died.
#
# b. A process won't hang when putting into a queue;
#
# We use `mp.Queue` which has a separate background thread to put
# objects from an unbounded buffer array. The background thread is
# daemonic and usually automatically joined when the process
# exits.
#
# However, in case that the receiver has ended abruptly while
# reading from the pipe, the join will hang forever. Therefore,
# for both `worker_result_queue` (worker -> main process/pin_memory_thread)
# and each `index_queue` (main process -> worker), we use
# `q.cancel_join_thread()` in sender process before any `q.put` to
# prevent this automatic join.
#
# Moreover, having all queues called `cancel_join_thread` makes
# implementing graceful shutdown logic in `__del__` much easier.
# It won't need to get from any queue, which would also need to be
# guarded by periodic status checks.
#
# Note that this may leave corrupted data in the queue, but we
# don't care about the data anyways once we are shutting down.
#
#
# Now let's get back to 1:
# how we gracefully exit the workers when the last reference to the
# iteartor is gone.
#
# To achieve this, we implement the following logic along with the design
# choices mentioned above:
#
# [worker processes]
# While loader process is alive:
# Get from index_queue.
# If got a `None`, exit.
# If get anything else,
# Check `done_event`.
# If set, continue to next iteration
# i.e., keep getting until see the `None`, then exit.
# Otherwise, process data.
# If timed out,
# No matter `done_event` is set (still need to see `None`) or not,
# must continue to next iteration .
#
# [pin_memory_thread]
# # No need to check main thread. If this thread is alive, the main loader
# # thread must be alive, because this thread is set as daemonic.
# While True:
# Get from index_queue.
# If got a `None`, exit.
# If get anything else,
# Check `done_event`.
# If set, continue to next iteration
# i.e., keep getting until see the `None`, then exit.
# Otherwise, process data.
#
# NOTE: we don't check the status of the main thread because
# 1. if the process is killed by fatal signal, `pin_memory_thread`
# ends.
# 2. in other cases, either the cleaning-up in __del__ or the
# automatic exit of daemonic thread will take care of it.
# This won't busy-wait either because `.get(timeout)` does not
# busy-wait.
#
# [main process]
# In the DataLoader Iter's `__del__`
# a. Set `done_event` (shared with `pin_memory_thread` and workers).
#
# Note: from here on, the workers & `pin_memory_thread` may exit at
# any time after they receive `None`.
#
# b. Exit `pin_memory_thread`
# i. Put `None` in `worker_result_queue`.
# ii. Join the `pin_memory_thread`.
#
# c. Exit the workers.
# i. Put `None` in each worker's `index_queue`.
# ii. Join the workers.
#
# NOTE: This has to be after (b) because it may leave corrupted data
# in `worker_result_queue`, which `pin_memory_thread` reads
# from.
#
# NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
# can be omitted
#
# NB: `done_event`s isn't strictly needed. E.g., we can just check for
# `None` from `index_queue`, but it allows us to skip wasting resources
# processing indices already in `index_queue` if we are already shutting
# down.
def __init__(self, loader):
self.dataset = loader.dataset
self.collate_fn = loader.collate_fn
@ -274,18 +510,19 @@ class _DataLoaderIter(object):
self.workers = []
for i in range(self.num_workers):
index_queue = multiprocessing.Queue()
index_queue.cancel_join_thread()
w = multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, index_queue,
self.worker_result_queue, self.done_event,
self.collate_fn, base_seed + i,
self.worker_init_fn, i))
w.daemon = True # ensure that the worker exits on process exit
# Process.start() actually take some time as it needs to start a
# process and pass the arguments over via a pipe. Therefore, we
# only add a worker to self.workers list after it started, so
# that we do not call .join() if program dies before it starts,
# and __del__ tries to join it but will get:
w.daemon = True
# NB: Process.start() actually take some time as it needs to
# start a process and pass the arguments over via a pipe.
# Therefore, we only add a worker to self.workers list after
# it started, so that we do not call .join() if program dies
# before it starts, and __del__ tries to join but will get:
# AssertionError: can only join a started process.
w.start()
self.index_queues.append(index_queue)
@ -295,8 +532,8 @@ class _DataLoaderIter(object):
self.data_queue = queue.Queue()
pin_memory_thread = threading.Thread(
target=_pin_memory_loop,
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
torch.cuda.current_device()))
args=(self.worker_result_queue, self.data_queue,
torch.cuda.current_device(), self.done_event))
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
@ -317,11 +554,25 @@ class _DataLoaderIter(object):
return len(self.batch_sampler)
def _get_batch(self):
# In the non-timeout case, worker exit is covered by SIGCHLD handler.
# But if `pin_memory=True`, we still need account for the possibility
# that `pin_memory_thread` dies.
if self.timeout > 0:
try:
return self.data_queue.get(timeout=self.timeout)
except queue.Empty:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
elif self.pin_memory:
while self.pin_memory_thread.is_alive():
try:
return self.data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
else:
# while condition is false, i.e., pin_memory_thread died.
raise RuntimeError('Pin memory thread exited unexpectedly')
# In this case, `self.data_queue` is a `queue.Queue`,. But we don't
# need to call `.task_done()` because we don't use `.join()`.
else:
return self.data_queue.get()
@ -383,29 +634,46 @@ class _DataLoaderIter(object):
raise NotImplementedError("_DataLoaderIter cannot be pickled")
def _shutdown_workers(self):
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
# the logic of this function.
if not self.shutdown:
self.shutdown = True
# removes pids from the C side data structure first so worker
# Removes pids from the C side data structure first so worker
# termination afterwards won't trigger false positive error report.
if self.worker_pids_set:
_remove_worker_pids(id(self))
self.worker_pids_set = False
self.done_event.set()
if self.pin_memory:
# Sending `None` to `pin_memory_thread` must be before
# stopping worker processes because the workers may leave
# corrupted data in `worker_result_queue`, causing
# `pin_memory_thread` unable to read and terminate properly.
# Exit `pin_memory_thread` first because exiting workers may leave
# corrupted data in `worker_result_queue` which `pin_memory_thread`
# reads from.
if hasattr(self, 'pin_memory_thread'):
# Use hasattr in case error happens before we set the attribute.
# First time do `worker_result_queue.put` in this process.
# `cancel_join_thread` in case that `pin_memory_thread` exited.
self.worker_result_queue.cancel_join_thread()
self.worker_result_queue.put(None)
# Workers can't be waiting to put be cause their output queue
# is a multiprocessing.Queue and its .put is non-blocking.
# They can only be waiting to get, so we put `None` here.
self.pin_memory_thread.join()
# Indicate that no more data will be put on this queue by the
# current process. This **must** be called after
# `pin_memory_thread` is joined because that thread shares the
# same pipe handles with this loader thread. If the handle is
# closed, Py3 will error in this case, but Py2 will just time
# out even if there is data in the queue.
self.worker_result_queue.close()
# Exit workers now.
for q in self.index_queues:
q.put(None)
# Indicate that no more data will be put on this queue by the
# current process.
q.close()
for w in self.workers:
w.join()
if hasattr(self, 'pin_memory_thread'):
self.pin_memory_thread.join()
def __del__(self):
if self.num_workers > 0: