Prevent hanging in data loader altogether

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11985

Differential Revision: D10202374

Pulled By: SsnL

fbshipit-source-id: 1ab1a07185f78a104f9b05930a87ef5a32f431e4
Author: Tongzhou Wang (2018-10-09 09:51:42 -07:00)
Committed by: Facebook Github Bot
Parent: 1a0d82e4f4
Commit: 11c31aef04
4 changed files with 522 additions and 140 deletions


@@ -91,7 +91,8 @@ TEST_MKL = torch.backends.mkl.is_available()
 # TODO: allow Py2 when librosa 0.6.2 releases
 TEST_LIBROSA = _check_module_exists('librosa') and PY3
-NO_MULTIPROCESSING_SPAWN = os.environ.get('NO_MULTIPROCESSING_SPAWN', '0') == '1'
+# Python 2.7 doesn't have spawn
+NO_MULTIPROCESSING_SPAWN = os.environ.get('NO_MULTIPROCESSING_SPAWN', '0') == '1' or sys.version_info[0] == 2
 TEST_WITH_ASAN = os.getenv('PYTORCH_TEST_WITH_ASAN', '0') == '1'
 TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1'
 TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1'


@@ -5,25 +5,25 @@ import os
 import ctypes
 import signal
 import torch
+import gc
 import time
 import traceback
 import unittest
 import subprocess
+import itertools
 from torch import multiprocessing as mp
 from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
 from torch.utils.data.dataset import random_split
-from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MANAGER_STATUS_CHECK_INTERVAL
+from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
 from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm
-# We cannot import TEST_CUDA from common_nn here, because if we do that,
-# the TEST_CUDNN line from common_nn will be executed multiple times
+# We cannot import TEST_CUDA from common_cuda here, because if we do that,
+# the TEST_CUDNN line from common_cuda will be executed multiple times
 # as well during the execution of this test suite, and it will cause
 # CUDA OOM error on Windows.
 TEST_CUDA = torch.cuda.is_available()
-# We need spawn start method for test_manager_unclean_exit, but
-# Python 2.7 doesn't allow it.
-if sys.version_info[0] == 3:
+if not NO_MULTIPROCESSING_SPAWN:
     # Get a multiprocessing context because some test / third party library will
     # set start_method when imported, and setting again triggers RuntimeError.
     mp = mp.get_context(method='spawn')
@@ -149,15 +149,12 @@ class ErrorTrackingProcess(mp.Process):
        self._exception = None

    def run(self):
-       # Disable stderr printing from os level, and make workers not printing
-       # to stderr.
-       # Can't use sys.stderr.close, otherwise Python `raise` will error with
-       # ValueError: I/O operation on closed file.
-       os.close(sys.stderr.fileno())
+       # Disable polluting stderr with errors that are supposed to happen.
+       sys.stderr = open(os.devnull, "w")
        try:
            super(ErrorTrackingProcess, self).run()
            self._cconn.send(None)
-       except Exception as e:
+       except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise
@@ -259,12 +256,94 @@ def _test_timeout():
    _ = next(iter(dataloader))
def _test_timeout_pin_memory():
dataset = SleepDataset(10, 3)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1, pin_memory=True)
_ = next(iter(dataloader))
def disable_stderr(worker_id):
r"""
Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
from workers. Since worker signal handler prints with low-level write(),
this has to be done on OS level via dup.
This is used as worker_init_fn for test_segfault.
"""
sys.stderr.flush() # flush library buffers that dup2 knows nothing about
# Can't use a with-block because otherwise the fd will be closed when this
# function ends.
devnull = open(os.devnull, 'w')
os.dup2(devnull.fileno(), sys.stderr.fileno())
def _test_segfault():
    dataset = SegfaultDataset(10)
-   dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
+   dataloader = DataLoader(dataset, batch_size=2, num_workers=2, worker_init_fn=disable_stderr)
    _ = next(iter(dataloader))
class TestProperExitDataset(object):
def __init__(self, size, error_event):
self.size = size
self.error_event = error_event
def __len__(self):
return self.size
def __getitem__(self, idx):
if self.error_event is not None and self.error_event.is_set():
raise RuntimeError('Worker error')
return torch.tensor([idx])
# See TestDataLoader.test_proper_exit for usage
def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
worker_pids, setup_event):
num_workers = 2 if use_workers else 0
if exit_method == 'worker_error' or exit_method == 'worker_kill':
assert use_workers is True
ds = TestProperExitDataset(16, setup_event if exit_method == 'worker_error' else None)
loader = DataLoader(ds, batch_size=2, shuffle=False,
num_workers=num_workers, pin_memory=pin_memory)
it = iter(loader)
if use_workers:
for i, w in enumerate(it.workers):
worker_pids[i] = w.pid
error_it = 4
assert len(loader) > error_it
def kill_pid(pid):
if IS_WINDOWS:
os.system('taskkill /PID ' + str(os.getpid()) + ' /F')
else:
os.kill(os.getpid(), signal.SIGKILL)
for i, _ in enumerate(it):
if i == 0:
if not hold_iter_reference:
del it
setup_event.set()
if i == error_it:
if exit_method == 'main_error':
raise RuntimeError('Error')
elif exit_method == 'main_kill':
kill_pid(os.getpid())
elif exit_method == 'worker_kill':
kill_pid(worker_pids[0])
if not hold_iter_reference:
# Tries to trigger the __del__ clean-up rather than the automatic
# exiting of daemonic children. Technically it should be automatically
# triggered, but I don't want to rely on the implementation detail of
# Python gc.
gc.collect()
# test custom init function
def init_fn(worker_id):
    torch.manual_seed(12345)
@@ -373,7 +452,12 @@ class TestDataLoader(TestCase):
    @skipIfRocm
    def test_timeout(self):
-       p = ErrorTrackingProcess(target=_test_timeout)
+       if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
+           targets = (_test_timeout, _test_timeout_pin_memory)
+       else:
+           targets = (_test_timeout,)
+       for target in targets:
+           p = ErrorTrackingProcess(target=target)
            p.start()
            p.join(JOIN_TIMEOUT)
            try:
@@ -506,10 +590,14 @@ class TestDataLoader(TestCase):
        self._test_error(DataLoader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
-   @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_partial_workers(self):
        r"""Check that workers exit even if the iterator is not exhausted."""
-       for pin_memory in (True, False):
+       if TEST_CUDA:
+           pin_memory_configs = (True, False)
+       else:
+           pin_memory_configs = (False,)
+       for pin_memory in pin_memory_configs:
            loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory))
            workers = loader.workers
            if pin_memory:
@@ -517,6 +605,7 @@ class TestDataLoader(TestCase):
            for i, sample in enumerate(loader):
                if i == 10:
                    break
+           assert i == 10
            del loader
            for w in workers:
                w.join(JOIN_TIMEOUT)
@@ -525,24 +614,6 @@ class TestDataLoader(TestCase):
                pin_memory_thread.join(JOIN_TIMEOUT)
                self.assertFalse(pin_memory_thread.is_alive())

-   @staticmethod
-   def _main_process(dataset, worker_pids, main_exit_event, raise_error):
-       loader = iter(DataLoader(dataset, batch_size=2, num_workers=4, pin_memory=True))
-       workers = loader.workers
-       for i in range(len(workers)):
-           worker_pids[i] = int(workers[i].pid)
-       for i, sample in enumerate(loader):
-           if i == 3:
-               # Simulate an exit of the manager process
-               main_exit_event.set()
-               if raise_error:
-                   raise RuntimeError('Error')
-               else:
-                   if IS_WINDOWS:
-                       os.system('taskkill /PID ' + str(os.getpid()) + ' /F')
-                   else:
-                       os.kill(os.getpid(), signal.SIGKILL)

    @staticmethod
    def _is_process_alive(pid, pname):
        # There is a chance of a terminated child process's pid being reused by a new unrelated process,
@@ -558,51 +629,94 @@ class TestDataLoader(TestCase):
        output = output.decode('utf-8')
        return pname in output

-   @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
-                    don't support multiprocessing with spawn start method")
-   @unittest.skipIf(sys.version_info[0] == 2,
-                    "spawn start method is not supported in Python 2, \
-                    but we need it for creating another process with CUDA")
-   @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @skipIfRocm
-   def test_main_process_unclean_exit(self):
-       r'''There might be ConnectionResetError or leaked semaphore warning (due to dirty process exit), \
-           but they are all safe to ignore'''
-       # `raise_error` controls if the main process is KILL-ed by OS or just
-       # simply raises an error. Both cases are interesting because
-       # 1. In case of it is KILL-ed by OS, the workers need to automatically
-       #    discover that their parent is dead and exit gracefully.
-       # 2. In case of it raises an error itself, the parent process needs to
-       #    take care of exiting the worker and then exits itself gracefully.
-       for raise_error in (True, False):
-           worker_pids = mp.Array('i', [0] * 4)
-           main_exit_event = mp.Event()
-           p = mp.Process(target=TestDataLoader._main_process,
-                          args=(self.dataset, worker_pids, main_exit_event, raise_error))
-           p.start()
-           worker_pids[-1] = p.pid
-           main_exit_event.wait()
-           exit_status = [False] * len(worker_pids)
-           start_time = time.time()
-           pname = 'python'
-           while True:
-               for i in range(len(worker_pids)):
-                   pid = worker_pids[i]
-                   if not exit_status[i]:
-                       if not TestDataLoader._is_process_alive(pid, pname):
-                           exit_status[i] = True
-               if all(exit_status):
-                   break
-               else:
-                   if time.time() - start_time > MANAGER_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT:
-                       self.fail('subprocess not terminated')
-                   time.sleep(1)
-           p.join(MANAGER_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT - (time.time() - start_time))
-           self.assertFalse(p.is_alive(), 'main process not terminated')
+   def test_proper_exit(self):
+       r'''There might be ConnectionResetError or leaked semaphore warning
+       (due to dirty process exit), but they are all safe to ignore'''
+       # TODO: test the case where the pin_memory_thread triggers an
+       #       error/fatal signal. I haven't found out how to properly do that.
+
+       # Array to store the worker pids.
+       worker_pids = mp.Array('i', [-1 for _ in range(10)])
+
+       def wait_pids(pids, timeout):
+           r"""Wait for all process specified in pids to exit in given timeout."""
+           exit_status = [False for _ in pids]
+           start_time = time.time()
+           pname = 'python'
+           while True:
+               for i in range(len(pids)):
+                   pid = pids[i]
+                   if not exit_status[i]:
+                       if not TestDataLoader._is_process_alive(pid, pname):
+                           exit_status[i] = True
+               if all(exit_status):
+                   break
+               else:
+                   if time.time() - start_time > timeout:
+                       break
+                   time.sleep(0.5)
+           return exit_status
for use_workers, pin_memory, hold_iter_reference in itertools.product([True, False], repeat=3):
# `hold_iter_reference` specifies whether we hold a reference to the
# iterator. This is interesting because Python3 error traces holds a
# reference to the frames, which hold references to all the local
# variables including the iterator, and then the iterator dtor may
# not be called before process end. It is important to see that the
# processes still exit in both cases.
if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN):
# Can't use CUDA without spawn
continue
# `exit_method` controls the way the loader process ends.
# - `*_kill` means that `*` is killed by OS.
# - `*_error` means that `*` raises an error.
# - `None` means that no error happens.
# In all cases, all processes should end properly.
if use_workers:
exit_methods = [None, 'main_error', 'main_kill', 'worker_kill', 'worker_error']
else:
exit_methods = [None, 'main_error', 'main_kill']
for exit_method in exit_methods:
# clear pids array first
for i in range(len(worker_pids)):
worker_pids[i] = -1
# Event that the loader process uses to signal testing process
# that various things are setup, including that the worker pids
# are specified in `worker_pids` array.
setup_event = mp.Event()
p = ErrorTrackingProcess(target=_test_proper_exit,
args=(use_workers, pin_memory, exit_method,
hold_iter_reference, worker_pids, setup_event))
p.start()
# Wait for loader process to set everything up, i.e., filling
# worker pids in `worker_pids`.
setup_event.wait(timeout=JOIN_TIMEOUT)
self.assertTrue(setup_event.is_set(), 'loader process setup timed out')
pids = [pid for pid in worker_pids if pid > 0]
try:
exit_status = wait_pids(pids, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
if not all(exit_status):
self.fail('subprocess (pid(s) {}) not terminated'.format(
', '.join(p for p, exited in zip(pids, exit_status) if not exited)))
p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
self.assertFalse(p.is_alive(), 'loader process not terminated')
if exit_method is None:
self.assertEqual(p.exitcode, 0)
else:
self.assertNotEqual(p.exitcode, 0)
finally:
p.terminate()
    def test_len(self):
        def check_len(dl, expected):


@@ -1,15 +1,14 @@
 #include "DataLoader.h"

-// In cases like DataLoader, if a worker process die due to bus error/segfault
-// or just hang, the main process, if implemented with
-// multiprocessing.queue.SimpleQueue, will hang waiting for data. This is
-// difficult to avoid on PyTorch side as it can be caused by limited shm, or
-// other libraries users call in the workers. The following methods is an effort
-// to do our best provide some error message to users when such unfortunate
-// events happen.
+// In cases like DataLoader, if a worker process dies due to bus error/segfault
+// or just hang, the main process will hang waiting for data. This is difficult
+// to avoid on PyTorch side as it can be caused by limited shm, or other
+// libraries users call in the workers. The following methods is an effort to do
+// our best to provide some error message to users when such unfortunate events
+// happen.

 // TODO: The following don't work on Windows. Specifically, sigaction, waitid
-// calls ,and SIGCHLD handler. Currently, dummy implementations are provided
+// calls, and SIGCHLD handler. Currently, dummy implementations are provided
 // for Windows.

 #ifndef _WIN32
@@ -63,6 +62,7 @@ static inline void setSignalHandler(int signal, void(*handler)(int, siginfo_t *,
SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. "
              "This might be caused by insufficient shared memory (shm).\n");
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");
+SIGNAL_HANDLER(SIGFPE, handler_SIGFPE, "ERROR: Unexpected floating-point exception encountered in worker.\n");

// When an error happend in DataLoader methods and Python starts to exit, the
// error trace will keep the loader alive, and Python may kill the children
@@ -92,6 +92,7 @@ static PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *a
  setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr);
  setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr);
  setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr);
+ setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
@@ -130,9 +131,7 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
    } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) {  // killed by signal
      std::ostringstream oss;
      oss << "DataLoader worker (pid " << worker_pid << ") is killed "
-         << "by signal: " << strsignal(infop.si_status) << ". "
-         << "Details are lost due to multiprocessing. Rerunning with "
-         << "num_workers=0 may give better error trace.";
+         << "by signal: " << strsignal(infop.si_status) << ". ";
      // This is necessary. Otherwise, the runtime error will kill the other
      // workers, and trigger this again.
      pid_set->clear();
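As an aside, the SIGCHLD/waitid mechanism that errorIfAnyWorkerFails relies on can be pictured with a rough pure-Python sketch. This is illustrative only, not code from this commit; the handler and registry names are made up, and it only applies on platforms that have SIGCHLD (the TODO above notes Windows does not).

import os
import signal

worker_pids = set()  # hypothetical registry of DataLoader worker pids

def _sketch_sigchld_handler(signum, frame):
    # Reap any worker that already exited and turn an abnormal exit into an
    # error, similar in spirit to errorIfAnyWorkerFails above.
    for pid in list(worker_pids):
        try:
            reaped, status = os.waitpid(pid, os.WNOHANG)
        except ChildProcessError:
            worker_pids.discard(pid)
            continue
        if reaped == pid:
            worker_pids.discard(pid)
            if os.WIFSIGNALED(status):
                raise RuntimeError('DataLoader worker (pid %d) is killed by signal %d'
                                   % (pid, os.WTERMSIG(status)))
            elif os.WIFEXITED(status) and os.WEXITSTATUS(status) != 0:
                raise RuntimeError('DataLoader worker (pid %d) exited unexpectedly with code %d'
                                   % (pid, os.WEXITSTATUS(status)))

signal.signal(signal.SIGCHLD, _sketch_sigchld_handler)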


@@ -26,10 +26,20 @@ else:
    import queue
# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding reference to the frame,
# and the frame (which holds reference to all the object in its temporary scope)
# holding reference the traceback.
class ExceptionWrapper(object):
    r"""Wraps an exception plus traceback to communicate across threads"""
    def __init__(self, exc_info):
+       # It is important that we don't store exc_info, see
+       # NOTE [ Python Traceback Reference Cycle Problem ]
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
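The note above is easiest to see in a standalone sketch (illustrative only, not part of the diff): keeping exc_info[2] alive keeps every frame reachable from the traceback alive, so only the type and a pre-formatted string are kept.

import sys
import traceback

def _sketch_capture():
    try:
        raise ValueError('boom')
    except Exception:
        exc_info = sys.exc_info()
        # Storing exc_info (or exc_info[2]) on an object would create the
        # traceback -> frame -> traceback cycle described in the note.
        return exc_info[0], ''.join(traceback.format_exception(*exc_info))

def _sketch_reraise(exc_type, exc_msg):
    # Roughly how a consumer can surface the wrapped error again later.
    raise exc_type(exc_msg)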
@@ -37,7 +47,11 @@ class ExceptionWrapper(object):
_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""

-MANAGER_STATUS_CHECK_INTERVAL = 5.0
+MP_STATUS_CHECK_INTERVAL = 5.0
r"""Interval (in seconds) to check status of processes to avoid hanging in
multiprocessing data loading. This is mainly used in getting data from
another process, in which case we need to periodically check whether the
sender is alive to prevent hanging."""
if IS_WINDOWS:
    # On Windows, the parent ID of the worker process remains unchanged when the manager process
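MP_STATUS_CHECK_INTERVAL is the timeout used to turn otherwise-blocking queue reads into loops that can notice a dead sender. A minimal sketch of that pattern follows (illustrative only; `sender_is_alive` stands in for ManagerWatchdog().is_alive in workers or pin_memory_thread.is_alive in the main process):

import queue

def _sketch_get_with_status_check(q, sender_is_alive, interval=5.0):
    # Keep retrying the get with a short timeout; between attempts, check that
    # whatever is supposed to put data into `q` is still alive.
    while sender_is_alive():
        try:
            return q.get(timeout=interval)
        except queue.Empty:
            continue
    raise RuntimeError('sender exited before producing data')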
@@ -60,19 +74,29 @@ if IS_WINDOWS:
            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())
+           self.manager_dead = False

        def is_alive(self):
+           if not self.manager_dead:
                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
-           return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
+               self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
+           return not self.manager_dead
else:
    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()
+           self.manager_dead = False

        def is_alive(self):
-           return os.getppid() == self.manager_pid
+           if not self.manager_dead:
+               self.manager_dead = os.getppid() != self.manager_pid
+           return not self.manager_dead
def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
+   # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+   # logic of this function.
    try:
        global _use_shared_memory
        _use_shared_memory = True
@@ -87,9 +111,6 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
        random.seed(seed)
        torch.manual_seed(seed)

-       # Do not wait for putting thread to join when this worker exits.
-       # Otherwise, this worker may always be waiting to put and doesn't check
-       # index_queue and done_event for termination signal.
        data_queue.cancel_join_thread()

        if init_fn is not None:
@@ -97,22 +118,26 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
        watchdog = ManagerWatchdog()

-       while True:
+       while watchdog.is_alive():
            try:
-               r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
+               r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
-               if watchdog.is_alive() and not done_event.is_set():
                continue
-               else:
-                   break
-           # use done_event so that we can get faster exiting signal even if there
-           # are still indices in index_queue
-           if r is None or done_event.is_set():
-               break
+           if r is None:
+               # Received the final signal
+               assert done_event.is_set()
+               return
+           elif done_event.is_set():
+               # Done event is set. But I haven't received the final signal
+               # (None) yet. I will keep continuing until get it, and skip the
+               # processing steps.
+               continue
            idx, batch_indices = r
            try:
                samples = collate_fn([dataset[i] for i in batch_indices])
            except Exception:
+               # It is important that we don't store exc_info in a variable,
+               # see NOTE [ Python Traceback Reference Cycle Problem ]
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
@@ -122,25 +147,33 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
        pass

-def _pin_memory_loop(in_queue, out_queue, done_event, pin_memory, device_id):
-   if pin_memory:
-       torch.cuda.set_device(device_id)
+def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
+   torch.cuda.set_device(device_id)
+   # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+   # logic of this function.
    while True:
        try:
-           r = in_queue.get()
+           r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+       except queue.Empty:
+           continue
        except Exception:
            if done_event.is_set():
-               return
+               # Weird things can happen when shutting down, e.g., fd being
+               # closed when tensors are shared via fds.
+               break
            raise
-       if r is None or done_event.is_set():
-           break
-       if isinstance(r[1], ExceptionWrapper):
-           out_queue.put(r)
-           continue
+       if r is None:
+           assert done_event.is_set()
+           return
+       elif done_event.is_set():
+           # Haven't seen the final signal yet. Keep getting until None.
+           continue
+       elif isinstance(r[1], ExceptionWrapper):
+           out_queue.put(r)
+       else:
            idx, batch = r
            try:
-               if pin_memory:
                batch = pin_memory_batch(batch)
            except Exception:
                out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
@@ -230,6 +263,8 @@ def _set_SIGCHLD_handler():
        return
    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
+       # This doesn't catch default handler, but SIGCHLD default handler is a
+       # no-op.
        previous_handler = None

    def handler(signum, frame):
@@ -246,6 +281,207 @@ def _set_SIGCHLD_handler():
class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
# NOTE [ Data Loader Multiprocessing Shutdown Logic ]
#
# Preliminary:
#
# Our data model looks like this (queues are indicated with curly brackets):
#
# main process ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {data_queue} ||
# | ||
# data output \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
# `pin_memory=False`.
#
#
# Terminating multiprocessing logic requires very careful design. In
# particular, we need to make sure that
#
# 1. The iterator gracefully exits the workers when its last reference is
# gone.
#
# In this case, the workers should be gracefully exited because the
# main process may still need to continue to run, and we want cleaning
# up code in the workers to be executed (e.g., releasing GPU memory).
# Naturally, we implement the shutdown logic in `__del__` of
# DataLoaderIterator.
#
# We delay the discussion on the logic in this case until later.
#
# 2. The iterator exits the workers when the loader process and/or worker
# processes exits unexpectedly (e.g., SIGKILL-ed).
#
# We set all workers and `pin_memory_thread` to have `daemon=True`.
#
# You may ask, why can't we make the workers non-daemonic, and
# gracefully exit using the same logic as we have in `__del__` when the
# iterator gets deleted (see 1 above)?
#
# When a process ends, it shuts the all its daemonic children down with
# a SIGTERM (instead of joining them without a timeout). Simiarly for
# threads, but by a different mechanism. This fact, together with a few
# implementation details of multiprocessing, forces us to make workers
# daemonic. All of our problems arise when a DataLoader is used in a
# subprocess, and are caused by multiprocessing code which looks more
# or less like this:
#
# try:
# your_function_using_a_dataloader()
# finally:
# multiprocessing.util._exit_function()
#
# The joining/termination mentioned above happens inside
# `_exit_function()`. Now, if `your_function_using_a_dataloader()`
# throws, the stack trace stored in the exception will prevent the
# frame which uses `DataLoaderIter` to be freed. If the frame has any
# reference to the `DataLoaderIter` (e.g., in a method of the iter),
# its `__del__`, which starts the shutdown procedure, will not be
# called. That, in turn, means that workers aren't notified. Attempting
# to join in `_exit_function` will then result in a hang.
#
# For context, `_exit_function` is also registered as an `atexit` call.
# So it is unclear to me (@ssnl) why this is needed in a finally block.
# The code dates back to 2008 and there is no comment on the original
# PEP 371 or patch https://bugs.python.org/issue3050 (containing both
# the finally block and the `atexit` registration) that explains this.
#
# Another choice is to just shutdown workers with logic in 1 above
# whenever we see an error in `next`. This isn't ideal because
# a. It prevents users from using try-catch to resume data loading.
# b. It doesn't prevent hanging if users have references to the
# iterator.
#
# 3. All processes exit if any of them die unexpectedly (e.g., error,
# fatal signals).
#
# As shown above, the workers are set as daemonic children of the main
# process. However, automatic cleaning-up of such child processes only
# happens if the parent process exits gracefully (e.g., not via fatal
# signals like SIGKILL). So we must ensure that each process will exit
# even the process that should send/receive data to/from it were
# killed, i.e.,
#
# a. A process won't hang when getting from a queue.
#
# Even with carefully designed data dependencies (i.e., a `put()`
# always corresponding to a `get()`), hanging on `get()` can still
# happen when data in queue is corrupted (e.g., due to
# `cancel_join_thread` or unexpected exit).
#
# For child exit, we register SIGCHLD handler on main process,
# which checks if any of the workers fail in the (Python) handler.
# See DataLoader.cpp.
#
# For `.get()` calls where the sender(s) is not the workers, we
# guard them with timeouts, and check the status of the sender
# when timeout happens:
# + in the workers, the `ManagerWatchdog` class checks the main
# process status.
# + if `pin_memory=True`, when getting from `pin_memory_thread`,
# check `pin_memory_thread` status periodically until `.get()`
# returns or see that `pin_memory_thread` died.
#
# b. A process won't hang when putting into a queue;
#
# We use `mp.Queue` which has a separate background thread to put
# objects from an unbounded buffer array. The background thread is
# daemonic and usually automatically joined when the process
# exits.
#
# However, in case that the receiver has ended abruptly while
# reading from the pipe, the join will hang forever. Therefore,
# for both `worker_result_queue` (worker -> main process/pin_memory_thread)
# and each `index_queue` (main process -> worker), we use
# `q.cancel_join_thread()` in sender process before any `q.put` to
# prevent this automatic join.
#
# Moreover, having all queues called `cancel_join_thread` makes
# implementing graceful shutdown logic in `__del__` much easier.
# It won't need to get from any queue, which would also need to be
# guarded by periodic status checks.
#
# Note that this may leave corrupted data in the queue, but we
# don't care about the data anyways once we are shutting down.
#
#
# Now let's get back to 1:
# how we gracefully exit the workers when the last reference to the
# iteartor is gone.
#
# To achieve this, we implement the following logic along with the design
# choices mentioned above:
#
# [worker processes]
# While loader process is alive:
# Get from index_queue.
# If got a `None`, exit.
# If get anything else,
# Check `done_event`.
# If set, continue to next iteration
# i.e., keep getting until see the `None`, then exit.
# Otherwise, process data.
# If timed out,
# No matter `done_event` is set (still need to see `None`) or not,
# must continue to next iteration .
#
# [pin_memory_thread]
# # No need to check main thread. If this thread is alive, the main loader
# # thread must be alive, because this thread is set as daemonic.
# While True:
# Get from index_queue.
# If got a `None`, exit.
# If get anything else,
# Check `done_event`.
# If set, continue to next iteration
# i.e., keep getting until see the `None`, then exit.
# Otherwise, process data.
#
# NOTE: we don't check the status of the main thread because
# 1. if the process is killed by fatal signal, `pin_memory_thread`
# ends.
# 2. in other cases, either the cleaning-up in __del__ or the
# automatic exit of daemonic thread will take care of it.
# This won't busy-wait either because `.get(timeout)` does not
# busy-wait.
#
# [main process]
# In the DataLoader Iter's `__del__`
# a. Set `done_event` (shared with `pin_memory_thread` and workers).
#
# Note: from here on, the workers & `pin_memory_thread` may exit at
# any time after they receive `None`.
#
# b. Exit `pin_memory_thread`
# i. Put `None` in `worker_result_queue`.
# ii. Join the `pin_memory_thread`.
#
# c. Exit the workers.
# i. Put `None` in each worker's `index_queue`.
# ii. Join the workers.
#
# NOTE: This has to be after (b) because it may leave corrupted data
# in `worker_result_queue`, which `pin_memory_thread` reads
# from.
#
# NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
# can be omitted
#
# NB: `done_event`s isn't strictly needed. E.g., we can just check for
# `None` from `index_queue`, but it allows us to skip wasting resources
# processing indices already in `index_queue` if we are already shutting
# down.
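The handshake spelled out in the note above can be reduced to a tiny self-contained sketch (illustrative only; the `_toy_*` names are made up and the timeout/watchdog details are omitted):

import multiprocessing as mp

def _toy_worker(index_queue, result_queue, done_event):
    result_queue.cancel_join_thread()   # never block on put when exiting
    while True:
        r = index_queue.get()
        if r is None:                   # final signal from the owner
            assert done_event.is_set()
            return
        if done_event.is_set():         # shutting down: drain until the None arrives
            continue
        result_queue.put(r * 2)         # stand-in for real work

def _toy_shutdown(index_queue, worker, done_event):
    done_event.set()                    # step (a) in the note above
    index_queue.put(None)               # step (c-i): final signal to the worker
    worker.join()                       # step (c-ii): joining can no longer hang

if __name__ == '__main__':
    done_event = mp.Event()
    index_queue, result_queue = mp.Queue(), mp.Queue()
    w = mp.Process(target=_toy_worker, args=(index_queue, result_queue, done_event))
    w.daemon = True
    w.start()
    index_queue.put(21)
    print(result_queue.get())           # prints 42
    _toy_shutdown(index_queue, w, done_event)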
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
@@ -274,18 +510,19 @@ class _DataLoaderIter(object):
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
+               index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, index_queue,
                          self.worker_result_queue, self.done_event,
                          self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
-               w.daemon = True  # ensure that the worker exits on process exit
-               # Process.start() actually take some time as it needs to start a
-               # process and pass the arguments over via a pipe. Therefore, we
-               # only add a worker to self.workers list after it started, so
-               # that we do not call .join() if program dies before it starts,
-               # and __del__ tries to join it but will get:
+               w.daemon = True
+               # NB: Process.start() actually take some time as it needs to
+               #     start a process and pass the arguments over via a pipe.
+               #     Therefore, we only add a worker to self.workers list after
+               #     it started, so that we do not call .join() if program dies
+               #     before it starts, and __del__ tries to join but will get:
                #     AssertionError: can only join a started process.
                w.start()
                self.index_queues.append(index_queue)
@@ -295,8 +532,8 @@ class _DataLoaderIter(object):
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_pin_memory_loop,
-                   args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
-                         torch.cuda.current_device()))
+                   args=(self.worker_result_queue, self.data_queue,
+                         torch.cuda.current_device(), self.done_event))
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                # Similar to workers (see comment above), we only register
@@ -317,11 +554,25 @@ class _DataLoaderIter(object):
        return len(self.batch_sampler)

    def _get_batch(self):
# In the non-timeout case, worker exit is covered by SIGCHLD handler.
# But if `pin_memory=True`, we still need account for the possibility
# that `pin_memory_thread` dies.
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
elif self.pin_memory:
while self.pin_memory_thread.is_alive():
try:
return self.data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
else:
# while condition is false, i.e., pin_memory_thread died.
raise RuntimeError('Pin memory thread exited unexpectedly')
# In this case, `self.data_queue` is a `queue.Queue`,. But we don't
# need to call `.task_done()` because we don't use `.join()`.
        else:
            return self.data_queue.get()
@@ -383,29 +634,46 @@ class _DataLoaderIter(object):
        raise NotImplementedError("_DataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
+       # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
+       # the logic of this function.
        if not self.shutdown:
            self.shutdown = True
-           # removes pids from the C side data structure first so worker
+           # Removes pids from the C side data structure first so worker
            # termination afterwards won't trigger false positive error report.
            if self.worker_pids_set:
                _remove_worker_pids(id(self))
                self.worker_pids_set = False

            self.done_event.set()

-           if self.pin_memory:
-               # Sending `None` to `pin_memory_thread` must be before
-               # stopping worker processes because the workers may leave
-               # corrupted data in `worker_result_queue`, causing
-               # `pin_memory_thread` unable to read and terminate properly.
+           # Exit `pin_memory_thread` first because exiting workers may leave
+           # corrupted data in `worker_result_queue` which `pin_memory_thread`
+           # reads from.
+           if hasattr(self, 'pin_memory_thread'):
+               # Use hasattr in case error happens before we set the attribute.
+               # First time do `worker_result_queue.put` in this process.
+               # `cancel_join_thread` in case that `pin_memory_thread` exited.
+               self.worker_result_queue.cancel_join_thread()
                self.worker_result_queue.put(None)
-           # Workers can't be waiting to put be cause their output queue
-           # is a multiprocessing.Queue and its .put is non-blocking.
-           # They can only be waiting to get, so we put `None` here.
+               self.pin_memory_thread.join()
+
+               # Indicate that no more data will be put on this queue by the
+               # current process. This **must** be called after
+               # `pin_memory_thread` is joined because that thread shares the
+               # same pipe handles with this loader thread. If the handle is
+               # closed, Py3 will error in this case, but Py2 will just time
+               # out even if there is data in the queue.
+               self.worker_result_queue.close()

+           # Exit workers now.
            for q in self.index_queues:
                q.put(None)
+               # Indicate that no more data will be put on this queue by the
+               # current process.
+               q.close()
            for w in self.workers:
                w.join()
-           if hasattr(self, 'pin_memory_thread'):
-               self.pin_memory_thread.join()

    def __del__(self):
        if self.num_workers > 0: