mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00

Prevent hanging in data loader altogether

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11985

Differential Revision: D10202374
Pulled By: SsnL
fbshipit-source-id: 1ab1a07185f78a104f9b05930a87ef5a32f431e4

committed by Facebook Github Bot
parent 1a0d82e4f4
commit 11c31aef04

@@ -91,7 +91,8 @@ TEST_MKL = torch.backends.mkl.is_available()
# TODO: allow Py2 when librosa 0.6.2 releases
TEST_LIBROSA = _check_module_exists('librosa') and PY3

NO_MULTIPROCESSING_SPAWN = os.environ.get('NO_MULTIPROCESSING_SPAWN', '0') == '1'
# Python 2.7 doesn't have spawn
NO_MULTIPROCESSING_SPAWN = os.environ.get('NO_MULTIPROCESSING_SPAWN', '0') == '1' or sys.version_info[0] == 2
TEST_WITH_ASAN = os.getenv('PYTORCH_TEST_WITH_ASAN', '0') == '1'
TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1'
TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1'

@@ -5,25 +5,25 @@ import os
import ctypes
import signal
import torch
import gc
import time
import traceback
import unittest
import subprocess
import itertools
from torch import multiprocessing as mp
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data.dataset import random_split
from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MANAGER_STATUS_CHECK_INTERVAL
from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm

# We cannot import TEST_CUDA from common_nn here, because if we do that,
# the TEST_CUDNN line from common_nn will be executed multiple times
# We cannot import TEST_CUDA from common_cuda here, because if we do that,
# the TEST_CUDNN line from common_cuda will be executed multiple times
# as well during the execution of this test suite, and it will cause
# CUDA OOM error on Windows.
TEST_CUDA = torch.cuda.is_available()

# We need spawn start method for test_manager_unclean_exit, but
# Python 2.7 doesn't allow it.
if sys.version_info[0] == 3:
if not NO_MULTIPROCESSING_SPAWN:
    # Get a multiprocessing context because some test / third party library will
    # set start_method when imported, and setting again triggers RuntimeError.
    mp = mp.get_context(method='spawn')

@@ -149,15 +149,12 @@ class ErrorTrackingProcess(mp.Process):
        self._exception = None

    def run(self):
        # Disable stderr printing from os level, and make workers not printing
        # to stderr.
        # Can't use sys.stderr.close, otherwise Python `raise` will error with
        # ValueError: I/O operation on closed file.
        os.close(sys.stderr.fileno())
        # Disable polluting stderr with errors that are supposed to happen.
        sys.stderr = open(os.devnull, "w")
        try:
            super(ErrorTrackingProcess, self).run()
            self._cconn.send(None)
        except Exception as e:
        except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise
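
For illustration only, a minimal standalone sketch of the pattern the ErrorTrackingProcess helper above relies on: the child runs its target inside try/except and ships any failure back to the parent over a multiprocessing.Pipe, so the test process can assert on it. The names CapturingProcess and _run_and_fail are hypothetical and not part of this diff.

import multiprocessing as mp
import sys
import traceback


class CapturingProcess(mp.Process):
    """Child process that reports any exception back to the parent over a pipe."""

    def __init__(self, **kwargs):
        super(CapturingProcess, self).__init__(**kwargs)
        self._pconn, self._cconn = mp.Pipe()  # parent reads _pconn, child writes _cconn

    def run(self):
        try:
            super(CapturingProcess, self).run()
            self._cconn.send(None)  # clean exit
        except Exception:
            # Send only formatted text, not the traceback object itself.
            self._cconn.send("".join(traceback.format_exception(*sys.exc_info())))
            raise


def _run_and_fail():
    raise RuntimeError('boom')


if __name__ == '__main__':
    p = CapturingProcess(target=_run_and_fail)
    p.start()
    p.join()
    if p._pconn.poll():
        print(p._pconn.recv())  # the child's formatted traceback, or None on success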

@@ -259,12 +256,94 @@ def _test_timeout():
    _ = next(iter(dataloader))


def _test_timeout_pin_memory():
    dataset = SleepDataset(10, 3)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1, pin_memory=True)
    _ = next(iter(dataloader))


def disable_stderr(worker_id):
    r"""
    Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
    from workers. Since worker signal handler prints with low-level write(),
    this has to be done on OS level via dup.

    This is used as worker_init_fn for test_segfault.
    """
    sys.stderr.flush()  # flush library buffers that dup2 knows nothing about
    # Can't use a with-block because otherwise the fd will be closed when this
    # function ends.
    devnull = open(os.devnull, 'w')
    os.dup2(devnull.fileno(), sys.stderr.fileno())


def _test_segfault():
    dataset = SegfaultDataset(10)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, worker_init_fn=disable_stderr)
    _ = next(iter(dataloader))


class TestProperExitDataset(object):
    def __init__(self, size, error_event):
        self.size = size
        self.error_event = error_event

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if self.error_event is not None and self.error_event.is_set():
            raise RuntimeError('Worker error')
        return torch.tensor([idx])


# See TestDataLoader.test_proper_exit for usage
def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
                      worker_pids, setup_event):
    num_workers = 2 if use_workers else 0

    if exit_method == 'worker_error' or exit_method == 'worker_kill':
        assert use_workers is True

    ds = TestProperExitDataset(16, setup_event if exit_method == 'worker_error' else None)

    loader = DataLoader(ds, batch_size=2, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
    it = iter(loader)
    if use_workers:
        for i, w in enumerate(it.workers):
            worker_pids[i] = w.pid

    error_it = 4
    assert len(loader) > error_it

    def kill_pid(pid):
        if IS_WINDOWS:
            os.system('taskkill /PID ' + str(pid) + ' /F')
        else:
            os.kill(pid, signal.SIGKILL)

    for i, _ in enumerate(it):
        if i == 0:
            if not hold_iter_reference:
                del it
            setup_event.set()
        if i == error_it:
            if exit_method == 'main_error':
                raise RuntimeError('Error')
            elif exit_method == 'main_kill':
                kill_pid(os.getpid())
            elif exit_method == 'worker_kill':
                kill_pid(worker_pids[0])

    if not hold_iter_reference:
        # Tries to trigger the __del__ clean-up rather than the automatic
        # exiting of daemonic children. Technically it should be automatically
        # triggered, but I don't want to rely on the implementation detail of
        # Python gc.
        gc.collect()


# test custom init function
def init_fn(worker_id):
    torch.manual_seed(12345)

@@ -373,7 +452,12 @@ class TestDataLoader(TestCase):

    @skipIfRocm
    def test_timeout(self):
        p = ErrorTrackingProcess(target=_test_timeout)
        if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
            targets = (_test_timeout, _test_timeout_pin_memory)
        else:
            targets = (_test_timeout,)
        for target in targets:
            p = ErrorTrackingProcess(target=target)
            p.start()
            p.join(JOIN_TIMEOUT)
            try:

@@ -506,10 +590,14 @@ class TestDataLoader(TestCase):
        self._test_error(DataLoader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_partial_workers(self):
        r"""Check that workers exit even if the iterator is not exhausted."""
        for pin_memory in (True, False):
        if TEST_CUDA:
            pin_memory_configs = (True, False)
        else:
            pin_memory_configs = (False,)

        for pin_memory in pin_memory_configs:
            loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory))
            workers = loader.workers
            if pin_memory:

@@ -517,6 +605,7 @@ class TestDataLoader(TestCase):
            for i, sample in enumerate(loader):
                if i == 10:
                    break
            assert i == 10
            del loader
            for w in workers:
                w.join(JOIN_TIMEOUT)

@@ -525,24 +614,6 @@ class TestDataLoader(TestCase):
                pin_memory_thread.join(JOIN_TIMEOUT)
                self.assertFalse(pin_memory_thread.is_alive())

    @staticmethod
    def _main_process(dataset, worker_pids, main_exit_event, raise_error):
        loader = iter(DataLoader(dataset, batch_size=2, num_workers=4, pin_memory=True))
        workers = loader.workers
        for i in range(len(workers)):
            worker_pids[i] = int(workers[i].pid)
        for i, sample in enumerate(loader):
            if i == 3:
                # Simulate an exit of the manager process
                main_exit_event.set()
                if raise_error:
                    raise RuntimeError('Error')
                else:
                    if IS_WINDOWS:
                        os.system('taskkill /PID ' + str(os.getpid()) + ' /F')
                    else:
                        os.kill(os.getpid(), signal.SIGKILL)

    @staticmethod
    def _is_process_alive(pid, pname):
        # There is a chance of a terminated child process's pid being reused by a new unrelated process,

@@ -558,51 +629,94 @@ class TestDataLoader(TestCase):
        output = output.decode('utf-8')
        return pname in output

    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
                     don't support multiprocessing with spawn start method")
    @unittest.skipIf(sys.version_info[0] == 2,
                     "spawn start method is not supported in Python 2, \
                     but we need it for creating another process with CUDA")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @skipIfRocm
    def test_main_process_unclean_exit(self):
        r'''There might be ConnectionResetError or leaked semaphore warning (due to dirty process exit), \
            but they are all safe to ignore'''
    def test_proper_exit(self):
        r'''There might be ConnectionResetError or leaked semaphore warning
        (due to dirty process exit), but they are all safe to ignore'''

        # `raise_error` controls if the main process is KILL-ed by OS or just
        # simply raises an error. Both cases are interesting because
        # 1. In case of it is KILL-ed by OS, the workers need to automatically
        #    discover that their parent is dead and exit gracefully.
        # 2. In case of it raises an error itself, the parent process needs to
        #    take care of exiting the worker and then exits itself gracefully.
        for raise_error in (True, False):
            worker_pids = mp.Array('i', [0] * 4)
        # TODO: test the case where the pin_memory_thread triggers an
        # error/fatal signal. I haven't found out how to properly do that.

            main_exit_event = mp.Event()
            p = mp.Process(target=TestDataLoader._main_process,
                           args=(self.dataset, worker_pids, main_exit_event, raise_error))
            p.start()
            worker_pids[-1] = p.pid
        # Array to store the worker pids.
        worker_pids = mp.Array('i', [-1 for _ in range(10)])

            main_exit_event.wait()

            exit_status = [False] * len(worker_pids)
        def wait_pids(pids, timeout):
            r"""Wait for all process specified in pids to exit in given timeout."""
            exit_status = [False for _ in pids]
            start_time = time.time()
            pname = 'python'
            while True:
                for i in range(len(worker_pids)):
                    pid = worker_pids[i]
                for i in range(len(pids)):
                    pid = pids[i]
                    if not exit_status[i]:
                        if not TestDataLoader._is_process_alive(pid, pname):
                            exit_status[i] = True
                if all(exit_status):
                    break
                else:
                    if time.time() - start_time > MANAGER_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT:
                        self.fail('subprocess not terminated')
                    time.sleep(1)
            p.join(MANAGER_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT - (time.time() - start_time))
            self.assertFalse(p.is_alive(), 'main process not terminated')
                    if time.time() - start_time > timeout:
                        break
                    time.sleep(0.5)
            return exit_status

        for use_workers, pin_memory, hold_iter_reference in itertools.product([True, False], repeat=3):
            # `hold_iter_reference` specifies whether we hold a reference to the
            # iterator. This is interesting because Python3 error traces holds a
            # reference to the frames, which hold references to all the local
            # variables including the iterator, and then the iterator dtor may
            # not be called before process end. It is important to see that the
            # processes still exit in both cases.

            if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN):
                # Can't use CUDA without spawn
                continue

            # `exit_method` controls the way the loader process ends.
            # - `*_kill` means that `*` is killed by OS.
            # - `*_error` means that `*` raises an error.
            # - `None` means that no error happens.
            # In all cases, all processes should end properly.
            if use_workers:
                exit_methods = [None, 'main_error', 'main_kill', 'worker_kill', 'worker_error']
            else:
                exit_methods = [None, 'main_error', 'main_kill']

            for exit_method in exit_methods:

                # clear pids array first
                for i in range(len(worker_pids)):
                    worker_pids[i] = -1

                # Event that the loader process uses to signal testing process
                # that various things are setup, including that the worker pids
                # are specified in `worker_pids` array.
                setup_event = mp.Event()

                p = ErrorTrackingProcess(target=_test_proper_exit,
                                         args=(use_workers, pin_memory, exit_method,
                                               hold_iter_reference, worker_pids, setup_event))
                p.start()

                # Wait for loader process to set everything up, i.e., filling
                # worker pids in `worker_pids`.
                setup_event.wait(timeout=JOIN_TIMEOUT)
                self.assertTrue(setup_event.is_set(), 'loader process setup timed out')

                pids = [pid for pid in worker_pids if pid > 0]

                try:
                    exit_status = wait_pids(pids, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
                    if not all(exit_status):
                        self.fail('subprocess (pid(s) {}) not terminated'.format(
                            ', '.join(p for p, exited in zip(pids, exit_status) if not exited)))
                    p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
                    self.assertFalse(p.is_alive(), 'loader process not terminated')
                    if exit_method is None:
                        self.assertEqual(p.exitcode, 0)
                    else:
                        self.assertNotEqual(p.exitcode, 0)
                finally:
                    p.terminate()

    def test_len(self):
        def check_len(dl, expected):

@@ -1,15 +1,14 @@
#include "DataLoader.h"

// In cases like DataLoader, if a worker process die due to bus error/segfault
// or just hang, the main process, if implemented with
// multiprocessing.queue.SimpleQueue, will hang waiting for data. This is
// difficult to avoid on PyTorch side as it can be caused by limited shm, or
// other libraries users call in the workers. The following methods is an effort
// to do our best provide some error message to users when such unfortunate
// events happen.
// In cases like DataLoader, if a worker process dies due to bus error/segfault
// or just hang, the main process will hang waiting for data. This is difficult
// to avoid on PyTorch side as it can be caused by limited shm, or other
// libraries users call in the workers. The following methods is an effort to do
// our best to provide some error message to users when such unfortunate events
// happen.

// TODO: The following don't work on Windows. Specifically, sigaction, waitid
// calls ,and SIGCHLD handler. Currently, dummy implementations are provided
// calls, and SIGCHLD handler. Currently, dummy implementations are provided
// for Windows.

#ifndef _WIN32

@@ -63,6 +62,7 @@ static inline void setSignalHandler(int signal, void(*handler)(int, siginfo_t *,
SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. "
  "This might be caused by insufficient shared memory (shm).\n");
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");
SIGNAL_HANDLER(SIGFPE, handler_SIGFPE, "ERROR: Unexpected floating-point exception encountered in worker.\n");

// When an error happend in DataLoader methods and Python starts to exit, the
// error trace will keep the loader alive, and Python may kill the children

@@ -92,6 +92,7 @@ static PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *a
  setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr);
  setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr);
  setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr);
  setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

@@ -130,9 +131,7 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
    } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal
      std::ostringstream oss;
      oss << "DataLoader worker (pid " << worker_pid << ") is killed "
          << "by signal: " << strsignal(infop.si_status) << ". "
          << "Details are lost due to multiprocessing. Rerunning with "
          << "num_workers=0 may give better error trace.";
          << "by signal: " << strsignal(infop.si_status) << ". ";
      // This is necessary. Otherwise, the runtime error will kill the other
      // workers, and trigger this again.
      pid_set->clear();

@@ -26,10 +26,20 @@ else:
    import queue


# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding reference to the frame,
# and the frame (which holds reference to all the object in its temporary scope)
# holding reference the traceback.


class ExceptionWrapper(object):
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info):
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
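
As an aside, here is a small, hedged illustration (not from this commit) of the reference-cycle problem the NOTE above describes: holding on to exc_info pins the raising frame and everything in its local scope, while keeping only the formatted string, as ExceptionWrapper does, lets the garbage collector reclaim it.

import gc
import sys
import traceback
import weakref


class Payload(object):
    """Stand-in for a large local object (e.g., a loaded batch)."""


def fail_with_local():
    payload = Payload()                 # local we would like to see freed promptly
    ref = weakref.ref(payload)
    try:
        raise ValueError("boom")
    except ValueError:
        # Storing all of sys.exc_info() keeps exc_info[2] (the traceback),
        # which keeps this frame, which keeps `payload` and `exc_info` itself:
        # a reference cycle that only the garbage collector can break.
        exc_info = sys.exc_info()
    return ref, exc_info


ref, exc_info = fail_with_local()
gc.collect()
print(ref() is None)                    # False: the stored traceback pins the frame

# Keeping only a formatted string (as ExceptionWrapper does) avoids the cycle.
msg = "".join(traceback.format_exception(*exc_info))
del exc_info
gc.collect()
print(ref() is None)                    # True: frame and its locals are collectable again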

@@ -37,7 +47,11 @@ class ExceptionWrapper(object):
_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""

MANAGER_STATUS_CHECK_INTERVAL = 5.0
MP_STATUS_CHECK_INTERVAL = 5.0
r"""Interval (in seconds) to check status of processes to avoid hanging in
multiprocessing data loading. This is mainly used in getting data from
another process, in which case we need to periodically check whether the
sender is alive to prevent hanging."""

if IS_WINDOWS:
    # On Windows, the parent ID of the worker process remains unchanged when the manager process

@@ -60,19 +74,29 @@ if IS_WINDOWS:
            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())

            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
            return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
                self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
            return not self.manager_dead
else:
    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()
            self.manager_dead = False

        def is_alive(self):
            return os.getppid() == self.manager_pid
            if not self.manager_dead:
                self.manager_dead = os.getppid() != self.manager_pid
            return not self.manager_dead


def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        global _use_shared_memory
        _use_shared_memory = True
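
For context, a minimal self-contained sketch, under assumed toy names (STATUS_CHECK_INTERVAL, parent_is_alive, worker), of the general pattern the hunk above adopts: a worker never blocks forever on its queue; it wakes up every few seconds, re-checks that its parent is still alive, and honors a None sentinel for orderly shutdown. This illustrates the idea only, it is not the DataLoader implementation.

import multiprocessing as mp
import os
import queue  # only for the Empty exception type

STATUS_CHECK_INTERVAL = 5.0  # seconds between liveness checks (mirrors MP_STATUS_CHECK_INTERVAL)


def parent_is_alive(parent_pid):
    # On POSIX, a worker gets re-parented (e.g., to init) when its parent dies,
    # so a changed ppid means the parent is gone.
    return os.getppid() == parent_pid


def worker(task_queue, result_queue, parent_pid):
    while parent_is_alive(parent_pid):
        try:
            item = task_queue.get(timeout=STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue            # timed out: re-check the parent and wait again
        if item is None:        # sentinel: orderly shutdown requested
            return
        result_queue.put(item * item)


if __name__ == "__main__":
    tasks, results = mp.Queue(), mp.Queue()
    w = mp.Process(target=worker, args=(tasks, results, os.getpid()))
    w.daemon = True
    w.start()
    for i in range(3):
        tasks.put(i)
    print([results.get() for _ in range(3)])  # [0, 1, 4]
    tasks.put(None)             # ask the worker to exit
    w.join()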

@@ -87,9 +111,6 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
        random.seed(seed)
        torch.manual_seed(seed)

        # Do not wait for putting thread to join when this worker exits.
        # Otherwise, this worker may always be waiting to put and doesn't check
        # index_queue and done_event for termination signal.
        data_queue.cancel_join_thread()

        if init_fn is not None:

@@ -97,22 +118,26 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,

        watchdog = ManagerWatchdog()

        while True:
        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                if watchdog.is_alive() and not done_event.is_set():
                    continue
                else:
                    break
            # use done_event so that we can get faster exiting signal even if there
            # are still indices in index_queue
            if r is None or done_event.is_set():
                break
            if r is None:
                # Received the final signal
                assert done_event.is_set()
                return
            elif done_event.is_set():
                # Done event is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, batch_indices = r
            try:
                samples = collate_fn([dataset[i] for i in batch_indices])
            except Exception:
                # It is important that we don't store exc_info in a variable,
                # see NOTE [ Python Traceback Reference Cycle Problem ]
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))

@@ -122,25 +147,33 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
        pass


def _pin_memory_loop(in_queue, out_queue, done_event, pin_memory, device_id):
    if pin_memory:
def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    torch.cuda.set_device(device_id)

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while True:
        try:
            r = in_queue.get()
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue
        except Exception:
            if done_event.is_set():
                return
            raise
        if r is None or done_event.is_set():
            # Weird things can happen when shutting down, e.g., fd being
            # closed when tensors are shared via fds.
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            raise
        if r is None:
            assert done_event.is_set()
            return
        elif done_event.is_set():
            # Haven't seen the final signal yet. Keep getting until None.
            continue
        elif isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
        else:
            idx, batch = r
            try:
                if pin_memory:
                    batch = pin_memory_batch(batch)
            except Exception:
                out_queue.put((idx, ExceptionWrapper(sys.exc_info())))

@@ -230,6 +263,8 @@ def _set_SIGCHLD_handler():
        return
    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        # This doesn't catch default handler, but SIGCHLD default handler is a
        # no-op.
        previous_handler = None

    def handler(signum, frame):

@@ -246,6 +281,207 @@ def _set_SIGCHLD_handler():
class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    #
    # Preliminary:
    #
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #        pin_memory_thread of main process                 ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.
    #
    #
    # Terminating multiprocessing logic requires very careful design. In
    # particular, we need to make sure that
    #
    #   1. The iterator gracefully exits the workers when its last reference is
    #      gone.
    #
    #      In this case, the workers should be gracefully exited because the
    #      main process may still need to continue to run, and we want cleaning
    #      up code in the workers to be executed (e.g., releasing GPU memory).
    #      Naturally, we implement the shutdown logic in `__del__` of
    #      DataLoaderIterator.
    #
    #      We delay the discussion on the logic in this case until later.
    #
    #   2. The iterator exits the workers when the loader process and/or worker
    #      processes exits unexpectedly (e.g., SIGKILL-ed).
    #
    #      We set all workers and `pin_memory_thread` to have `daemon=True`.
    #
    #      You may ask, why can't we make the workers non-daemonic, and
    #      gracefully exit using the same logic as we have in `__del__` when the
    #      iterator gets deleted (see 1 above)?
    #
    #      When a process ends, it shuts the all its daemonic children down with
    #      a SIGTERM (instead of joining them without a timeout). Simiarly for
    #      threads, but by a different mechanism. This fact, together with a few
    #      implementation details of multiprocessing, forces us to make workers
    #      daemonic. All of our problems arise when a DataLoader is used in a
    #      subprocess, and are caused by multiprocessing code which looks more
    #      or less like this:
    #
    #          try:
    #              your_function_using_a_dataloader()
    #          finally:
    #              multiprocessing.util._exit_function()
    #
    #      The joining/termination mentioned above happens inside
    #      `_exit_function()`. Now, if `your_function_using_a_dataloader()`
    #      throws, the stack trace stored in the exception will prevent the
    #      frame which uses `DataLoaderIter` to be freed. If the frame has any
    #      reference to the `DataLoaderIter` (e.g., in a method of the iter),
    #      its `__del__`, which starts the shutdown procedure, will not be
    #      called. That, in turn, means that workers aren't notified. Attempting
    #      to join in `_exit_function` will then result in a hang.
    #
    #      For context, `_exit_function` is also registered as an `atexit` call.
    #      So it is unclear to me (@ssnl) why this is needed in a finally block.
    #      The code dates back to 2008 and there is no comment on the original
    #      PEP 371 or patch https://bugs.python.org/issue3050 (containing both
    #      the finally block and the `atexit` registration) that explains this.
    #
    #      Another choice is to just shutdown workers with logic in 1 above
    #      whenever we see an error in `next`. This isn't ideal because
    #        a. It prevents users from using try-catch to resume data loading.
    #        b. It doesn't prevent hanging if users have references to the
    #           iterator.
    #
    #   3. All processes exit if any of them die unexpectedly (e.g., error,
    #      fatal signals).
    #
    #      As shown above, the workers are set as daemonic children of the main
    #      process. However, automatic cleaning-up of such child processes only
    #      happens if the parent process exits gracefully (e.g., not via fatal
    #      signals like SIGKILL). So we must ensure that each process will exit
    #      even the process that should send/receive data to/from it were
    #      killed, i.e.,
    #
    #        a. A process won't hang when getting from a queue.
    #
    #           Even with carefully designed data dependencies (i.e., a `put()`
    #           always corresponding to a `get()`), hanging on `get()` can still
    #           happen when data in queue is corrupted (e.g., due to
    #           `cancel_join_thread` or unexpected exit).
    #
    #           For child exit, we register SIGCHLD handler on main process,
    #           which checks if any of the workers fail in the (Python) handler.
    #           See DataLoader.cpp.
    #
    #           For `.get()` calls where the sender(s) is not the workers, we
    #           guard them with timeouts, and check the status of the sender
    #           when timeout happens:
    #             + in the workers, the `ManagerWatchdog` class checks the main
    #               process status.
    #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
    #               check `pin_memory_thread` status periodically until `.get()`
    #               returns or see that `pin_memory_thread` died.
    #
    #        b. A process won't hang when putting into a queue;
    #
    #           We use `mp.Queue` which has a separate background thread to put
    #           objects from an unbounded buffer array. The background thread is
    #           daemonic and usually automatically joined when the process
    #           exits.
    #
    #           However, in case that the receiver has ended abruptly while
    #           reading from the pipe, the join will hang forever. Therefore,
    #           for both `worker_result_queue` (worker -> main process/pin_memory_thread)
    #           and each `index_queue` (main process -> worker), we use
    #           `q.cancel_join_thread()` in sender process before any `q.put` to
    #           prevent this automatic join.
    #
    #           Moreover, having all queues called `cancel_join_thread` makes
    #           implementing graceful shutdown logic in `__del__` much easier.
    #           It won't need to get from any queue, which would also need to be
    #           guarded by periodic status checks.
    #
    #           Note that this may leave corrupted data in the queue, but we
    #           don't care about the data anyways once we are shutting down.
    #
    #
    # Now let's get back to 1:
    #   how we gracefully exit the workers when the last reference to the
    #   iteartor is gone.
    #
    # To achieve this, we implement the following logic along with the design
    # choices mentioned above:
    #
    # [worker processes]
    #   While loader process is alive:
    #     Get from index_queue.
    #       If got a `None`, exit.
    #       If get anything else,
    #          Check `done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until see the `None`, then exit.
    #            Otherwise, process data.
    #       If timed out,
    #          No matter `done_event` is set (still need to see `None`) or not,
    #          must continue to next iteration .
    #
    # [pin_memory_thread]
    #   # No need to check main thread. If this thread is alive, the main loader
    #   # thread must be alive, because this thread is set as daemonic.
    #   While True:
    #     Get from index_queue.
    #       If got a `None`, exit.
    #       If get anything else,
    #          Check `done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until see the `None`, then exit.
    #            Otherwise, process data.
    #
    #   NOTE: we don't check the status of the main thread because
    #           1. if the process is killed by fatal signal, `pin_memory_thread`
    #              ends.
    #           2. in other cases, either the cleaning-up in __del__ or the
    #              automatic exit of daemonic thread will take care of it.
    #              This won't busy-wait either because `.get(timeout)` does not
    #              busy-wait.
    #
    # [main process]
    #   In the DataLoader Iter's `__del__`
    #     a. Set `done_event` (shared with `pin_memory_thread` and workers).
    #
    #        Note: from here on, the workers & `pin_memory_thread` may exit at
    #              any time after they receive `None`.
    #
    #     b. Exit `pin_memory_thread`
    #          i.   Put `None` in `worker_result_queue`.
    #          ii.  Join the `pin_memory_thread`.
    #
    #     c. Exit the workers.
    #          i.   Put `None` in each worker's `index_queue`.
    #          ii.  Join the workers.
    #
    #        NOTE: This has to be after (b) because it may leave corrupted data
    #              in `worker_result_queue`, which `pin_memory_thread` reads
    #              from.
    #
    #        NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
    #              can be omitted
    #
    # NB: `done_event`s isn't strictly needed. E.g., we can just check for
    #     `None` from `index_queue`, but it allows us to skip wasting resources
    #     processing indices already in `index_queue` if we are already shutting
    #     down.
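
To make the shutdown order described in the note above concrete, here is a hedged standalone sketch with toy names (worker, shutdown), not the DataLoader code itself: set a shared done event, unblock every consumer with a None sentinel, then join, with cancel_join_thread() called on queues whose sender may go away so that process exit cannot hang on a queue feeder thread.

import multiprocessing as mp
import queue

CHECK_INTERVAL = 5.0


def worker(index_queue, result_queue, done_event):
    # The sender side calls cancel_join_thread() so this process can exit
    # even if the receiver disappears while data is still buffered.
    result_queue.cancel_join_thread()
    while True:
        try:
            idx = index_queue.get(timeout=CHECK_INTERVAL)
        except queue.Empty:
            continue
        if idx is None:                   # final signal
            assert done_event.is_set()
            return
        if done_event.is_set():           # shutting down: drain until the None
            continue
        result_queue.put((idx, idx * 2))


def shutdown(index_queues, workers, done_event):
    # Order matters: set the event first, then unblock every consumer with a
    # sentinel, then join.
    done_event.set()
    for q in index_queues:
        q.put(None)
    for w in workers:
        w.join()


if __name__ == "__main__":
    done = mp.Event()
    results = mp.Queue()
    index_queues, workers = [], []
    for _ in range(2):
        iq = mp.Queue()
        iq.cancel_join_thread()           # the main process is the sender on this queue
        w = mp.Process(target=worker, args=(iq, results, done))
        w.daemon = True
        w.start()
        index_queues.append(iq)
        workers.append(w)
    index_queues[0].put(3)
    print(results.get())                  # (3, 6)
    shutdown(index_queues, workers, done)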

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn

@@ -274,18 +510,19 @@ class _DataLoaderIter(object):
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, index_queue,
                          self.worker_result_queue, self.done_event,
                          self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                w.daemon = True  # ensure that the worker exits on process exit
                # Process.start() actually take some time as it needs to start a
                # process and pass the arguments over via a pipe. Therefore, we
                # only add a worker to self.workers list after it started, so
                # that we do not call .join() if program dies before it starts,
                # and __del__ tries to join it but will get:
                w.daemon = True
                # NB: Process.start() actually take some time as it needs to
                #     start a process and pass the arguments over via a pipe.
                #     Therefore, we only add a worker to self.workers list after
                #     it started, so that we do not call .join() if program dies
                #     before it starts, and __del__ tries to join but will get:
                #     AssertionError: can only join a started process.
                w.start()
                self.index_queues.append(index_queue)

@@ -295,8 +532,8 @@ class _DataLoaderIter(object):
            self.data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(
                target=_pin_memory_loop,
                args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                      torch.cuda.current_device()))
                args=(self.worker_result_queue, self.data_queue,
                      torch.cuda.current_device(), self.done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register

@@ -317,11 +554,25 @@ class _DataLoaderIter(object):
        return len(self.batch_sampler)

    def _get_batch(self):
        # In the non-timeout case, worker exit is covered by SIGCHLD handler.
        # But if `pin_memory=True`, we still need account for the possibility
        # that `pin_memory_thread` dies.
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        elif self.pin_memory:
            while self.pin_memory_thread.is_alive():
                try:
                    return self.data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
                except queue.Empty:
                    continue
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self.data_queue` is a `queue.Queue`,. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            return self.data_queue.get()
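
For reference, a small hedged sketch in generic threading terms (get_or_fail is a made-up helper, not this file's API) of the _get_batch idea above: bound every blocking get with a short timeout and re-check that the producer thread is still alive, so a dead producer surfaces as an error instead of a silent hang.

import queue
import threading
import time

STATUS_CHECK_INTERVAL = 5.0


def get_or_fail(q, producer):
    """Block on q.get(), but fail fast if the producer thread has died."""
    while producer.is_alive():
        try:
            return q.get(timeout=STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue
    raise RuntimeError('producer thread exited unexpectedly')


if __name__ == "__main__":
    q = queue.Queue()

    def produce():
        time.sleep(0.1)
        q.put('batch-0')
        time.sleep(60)        # stays alive for the whole iteration, like a pin_memory_thread

    producer = threading.Thread(target=produce)
    producer.daemon = True
    producer.start()
    print(get_or_fail(q, producer))   # 'batch-0'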

@@ -383,29 +634,46 @@
        raise NotImplementedError("_DataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
        # the logic of this function.
        if not self.shutdown:
            self.shutdown = True
            # removes pids from the C side data structure first so worker
            # Removes pids from the C side data structure first so worker
            # termination afterwards won't trigger false positive error report.
            if self.worker_pids_set:
                _remove_worker_pids(id(self))
                self.worker_pids_set = False

            self.done_event.set()
            if self.pin_memory:
                # Sending `None` to `pin_memory_thread` must be before
                # stopping worker processes because the workers may leave
                # corrupted data in `worker_result_queue`, causing
                # `pin_memory_thread` unable to read and terminate properly.

            # Exit `pin_memory_thread` first because exiting workers may leave
            # corrupted data in `worker_result_queue` which `pin_memory_thread`
            # reads from.
            if hasattr(self, 'pin_memory_thread'):
                # Use hasattr in case error happens before we set the attribute.
                # First time do `worker_result_queue.put` in this process.

                # `cancel_join_thread` in case that `pin_memory_thread` exited.
                self.worker_result_queue.cancel_join_thread()
                self.worker_result_queue.put(None)
                # Workers can't be waiting to put be cause their output queue
                # is a multiprocessing.Queue and its .put is non-blocking.
                # They can only be waiting to get, so we put `None` here.
                self.pin_memory_thread.join()

                # Indicate that no more data will be put on this queue by the
                # current process. This **must** be called after
                # `pin_memory_thread` is joined because that thread shares the
                # same pipe handles with this loader thread. If the handle is
                # closed, Py3 will error in this case, but Py2 will just time
                # out even if there is data in the queue.
                self.worker_result_queue.close()

            # Exit workers now.
            for q in self.index_queues:
                q.put(None)
                # Indicate that no more data will be put on this queue by the
                # current process.
                q.close()
            for w in self.workers:
                w.join()
            if hasattr(self, 'pin_memory_thread'):
                self.pin_memory_thread.join()

    def __del__(self):
        if self.num_workers > 0: