Improve test_proper_exit error printing (#20166)

Summary:
This doesn't have `strace` yet. But still have `faulthandler` to print stack traces at hanging. Also part of an attempt to isolate changes from #19228 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20166

Differential Revision: D15536504

Pulled By: ezyang

fbshipit-source-id: fe6e6e2e9899f30d8167436d7bc62b42883a3356
This commit is contained in:
Tongzhou Wang
2019-05-29 07:48:27 -07:00
committed by Facebook Github Bot
parent aa42742df0
commit 1d4685c20f

View File

@ -14,7 +14,7 @@ from torch import multiprocessing as mp
from torch.utils.data import _utils, Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data._utils import ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataset import random_split
from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS,
from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, PY3,
IS_PYTORCH_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm,
load_tests)
@ -200,6 +200,33 @@ class TestConcatDataset(TestCase):
self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())
# takes in dummy var so this can also be used as a `worker_init_fn`
def set_faulthander_if_available(_=None):
if HAS_FAULTHANDLER:
faulthandler.enable()
if not IS_WINDOWS:
# windows does not have faulthandler.register
# chain=False prevents the default behavior of killing the process
faulthandler.register(signal.SIGUSR1, chain=False)
# Process `pid` must have called `set_faulthander_if_available`
def print_traces_of_all_threads(pid):
if HAS_FAULTHANDLER:
if not IS_WINDOWS:
# use the custom signal if available
os.kill(pid, signal.SIGUSR1)
else:
# otherwise we can still use the handler given by faulthandler.enable()
# at the cost of killing the process.
os.kill(pid, signal.SIGSEGV)
else:
# if there is no faulthandler, use SIGINT otherwise and hope for the best
os.kill(pid, signal.SIGINT)
# wait in parent process to give subprocess some time to print
time.sleep(5)
# Stores the first encountered exception in .exception.
# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(mp.Process):
@ -215,11 +242,7 @@ class ErrorTrackingProcess(mp.Process):
self.disable_stderr = disable_stderr
def run(self):
if HAS_FAULTHANDLER:
faulthandler.enable()
if not IS_WINDOWS:
# windows does not have faulthandler.register
faulthandler.register(signal.SIGUSR1, chain=True)
set_faulthander_if_available()
if self.disable_stderr:
# Disable polluting stderr with errors that are supposed to happen.
sys.stderr = open(os.devnull, "w")
@ -233,20 +256,11 @@ class ErrorTrackingProcess(mp.Process):
def print_traces_of_all_threads(self):
assert self.is_alive(), "can only use print_traces_of_all_threads if the process is alive"
assert not self.disable_stderr, "do not disable stderr if you use print_traces_of_all_threads"
if HAS_FAULTHANDLER:
if not IS_WINDOWS:
# use the custom signal if available
os.kill(self.pid, signal.SIGUSR1)
else:
# otherwise we can still use the handler given by faulthandler.enable()
# at the cost of killing the process, so let's poll the exception first
_ = self.exception
os.kill(self.pid, signal.SIGSEGV)
else:
# if there is no faulthandler, use SIGINT otherwise and hope for the best
os.kill(self.pid, signal.SIGINT)
# wait in parent process to give subprocess some time to print
time.sleep(5)
# On platforms without `SIGUSR1`, `set_faulthander_if_available` sets
# `faulthandler.enable()`, and `print_traces_of_all_threads` may kill
# the process. So let's poll the exception first
_ = self.exception
print_traces_of_all_threads(self.pid)
@property
def exception(self):
@ -403,7 +417,8 @@ def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
ds = TestProperExitDataset(12, worker_error_event)
loader = DataLoader(ds, batch_size=1, shuffle=False,
num_workers=num_workers, pin_memory=pin_memory)
num_workers=num_workers, pin_memory=pin_memory,
worker_init_fn=set_faulthander_if_available)
error_it = 2
if use_workers:
@ -832,48 +847,86 @@ class TestDataLoader(TestCase):
tester_setup_event),
disable_stderr=False)
loader_p.start()
loader_psutil_p = psutil.Process(loader_p.pid)
# Wait for loader process to set everything up, e.g., starting
# workers.
loader_setup_event.wait(timeout=JOIN_TIMEOUT)
if not loader_setup_event.is_set():
loader_p.print_traces_of_all_threads()
fail_msg = desc + ': loader process failed to setup within given time'
if loader_p.exception is not None:
self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception))
fail_msg += ', and had exception {}'.format(loader_p.exception)
elif not loader_p.is_alive():
self.fail(fail_msg + ', and exited with code {} but had no exception'.format(loader_p.exitcode))
fail_msg += ', and exited with code {} but had no exception'.format(loader_p.exitcode)
else:
self.fail(fail_msg + ', and is still alive.')
fail_msg += ', and is still alive.'
if loader_p.is_alive():
# this may kill the process, needs to run after the above lines
loader_p.print_traces_of_all_threads()
self.fail(fail_msg)
worker_psutil_p = psutil.Process(loader_p.pid).children()
# We are certain that the workers have started now.
worker_psutil_ps = loader_psutil_p.children()
def fail(reason):
report_psutil_attrs = ['pid', 'name', 'cpu_times', 'io_counters',
'memory_full_info', 'num_ctx_switches',
'open_files', 'threads', 'status',
'nice', 'ionice']
if reason is None:
err_msg = desc
else:
err_msg = '{}: {}'.format(desc, reason)
err_msg += '\nLoader info:\n\t'
if loader_psutil_p.is_running():
err_msg += str(loader_psutil_p.as_dict(attrs=report_psutil_attrs))
# this may kill the process, needs to run after the above line
loader_p.print_traces_of_all_threads()
else:
err_msg += 'exited with code {}'.format(loader_p.exitcode)
if use_workers:
err_msg += '\nWorker(s) info:'
for idx, worker_psutil_p in enumerate(worker_psutil_ps):
err_msg += '\n\tWorker {}:\n\t\t'.format(idx)
if worker_psutil_p.is_running():
err_msg += str(worker_psutil_p.as_dict(attrs=report_psutil_attrs))
# this may kill the process, needs to run after the above line
print_traces_of_all_threads(worker_psutil_p.pid)
else:
err_msg += 'exited with unknown code'
self.fail(err_msg)
tester_setup_event.set()
try:
loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
if loader_p.is_alive():
loader_p.print_traces_of_all_threads()
fail_msg = desc + ': loader process did not terminate'
fail_reason = 'loader process did not terminate'
if loader_p.exception is not None:
self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception))
fail(fail_reason + ', and had exception {}'.format(loader_p.exception))
else:
self.fail(fail_msg + ', and had no exception')
_, alive = psutil.wait_procs(worker_psutil_p, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
fail(fail_reason + ', and had no exception')
_, alive = psutil.wait_procs(worker_psutil_ps, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
if len(alive) > 0:
self.fail(desc + ': worker process (pid(s) {}) did not terminate'.format(
', '.join(str(p.pid) for p in alive)))
self.fail(get_fail_msg('worker process (pid(s) {}) did not terminate'.format(
', '.join(str(p.pid) for p in alive))))
if exit_method is None:
self.assertEqual(loader_p.exitcode, 0)
if loader_p.exitcode != 0:
fail('loader process had nonzero exitcode {}'.format(loader_p.exitcode))
else:
self.assertNotEqual(loader_p.exitcode, 0)
if loader_p.exitcode == 0:
fail('loader process had zero exitcode')
if exit_method == 'loader_error':
self.assertIsInstance(loader_p.exception, RuntimeError, desc)
self.assertIn('Loader error', str(loader_p.exception), desc)
if not isinstance(loader_p.exception, RuntimeError) or \
'Loader error' not in str(loader_p.exception):
fail('loader process did not raise expected exception, but had {}'.format(
loader_p.exception))
elif exit_method == 'worker_kill':
if isinstance(loader_p.exception, RuntimeError):
self.assertIn('DataLoader worker (pid', str(loader_p.exception), desc)
elif isinstance(loader_p.exception, ConnectionRefusedError):
if 'DataLoader worker (pid' not in str(loader_p.exception):
fail('loader process did not raise expected exception, but had {}'.format(
loader_p.exception))
elif PY3 and isinstance(loader_p.exception, ConnectionRefusedError):
# Sometimes, when the worker is being killed and is freeing its
# resources, the unpickling in loader process will be met an
# a `ConnectionRefusedError` as it can not open a socket to receive
@ -882,11 +935,20 @@ class TestDataLoader(TestCase):
# handler. So we permit this as an allowed error as well.
# After all, we are happy as long as it terminates.
pass
elif not Py3 and isinstance(loader_p.exception, OSError):
# Same reasoning as the above if-block for Py2,
# where ConnectionRefusedError isn't a thing.
if loader_p.exception.errno != errno.ECONNREFUSED:
fail('loader process did not raise expected exception, but had {}'.format(
loader_p.exception))
else:
self.fail(desc)
fail('loader process did not raise expected exception, but had {}'.format(
loader_p.exception))
elif exit_method == 'worker_error':
self.assertIsInstance(loader_p.exception, RuntimeError, desc)
self.assertIn('Worker error', str(loader_p.exception), desc)
if not isinstance(loader_p.exception, RuntimeError) or \
'Worker error' not in str(loader_p.exception):
fail('loader process did not raise expected exception, but had {}'.format(
loader_p.exception))
finally:
loader_p.terminate()