mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 13:34:57 +08:00
Improve test_proper_exit error printing (#20166)
Summary: This doesn't have `strace` yet, but it still uses `faulthandler` to print stack traces when a process hangs. It is also part of an attempt to isolate changes from #19228. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20166 Differential Revision: D15536504 Pulled By: ezyang fbshipit-source-id: fe6e6e2e9899f30d8167436d7bc62b42883a3356
This commit is contained in:
committed by
Facebook Github Bot
parent
aa42742df0
commit
1d4685c20f
@ -14,7 +14,7 @@ from torch import multiprocessing as mp
|
||||
from torch.utils.data import _utils, Dataset, TensorDataset, DataLoader, ConcatDataset
|
||||
from torch.utils.data._utils import ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
|
||||
from torch.utils.data.dataset import random_split
|
||||
from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS,
|
||||
from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, PY3,
|
||||
IS_PYTORCH_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm,
|
||||
load_tests)
|
||||
|
||||
@ -200,6 +200,33 @@ class TestConcatDataset(TestCase):
|
||||
self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())
|
||||
|
||||
|
||||
# takes in dummy var so this can also be used as a `worker_init_fn`
|
||||
def set_faulthander_if_available(_=None):
    """Turn on `faulthandler` diagnostics in the calling process, if possible.

    Does nothing when `HAS_FAULTHANDLER` is false. Otherwise enables
    `faulthandler` and, on non-Windows platforms, additionally registers a
    `SIGUSR1` handler.

    The single argument is ignored; it exists only so that this function can
    be passed directly as a DataLoader `worker_init_fn`.
    """
    if not HAS_FAULTHANDLER:
        return
    faulthandler.enable()
    if IS_WINDOWS:
        # windows does not have faulthandler.register
        return
    # chain=False prevents the default behavior of killing the process
    faulthandler.register(signal.SIGUSR1, chain=False)
|
||||
|
||||
|
||||
# Process `pid` must have called `set_faulthander_if_available`
|
||||
def print_traces_of_all_threads(pid):
    """Signal process `pid` so that it prints its stack traces, then wait.

    The target process is expected to have called
    `set_faulthander_if_available` beforehand.

    Signal selection:
      * faulthandler available, not Windows: `SIGUSR1` (the custom signal).
      * faulthandler available, Windows: `SIGSEGV`, relying on the handler
        given by `faulthandler.enable()`, at the cost of killing the process.
      * no faulthandler: `SIGINT`, and hope for the best.

    Args:
        pid: id of the process to signal.
    """
    if not HAS_FAULTHANDLER:
        # if there is no faulthandler, use SIGINT otherwise and hope for the best
        signum = signal.SIGINT
    elif IS_WINDOWS:
        # no custom signal here; we can still use the handler given by
        # faulthandler.enable() at the cost of killing the process.
        signum = signal.SIGSEGV
    else:
        # use the custom signal if available
        signum = signal.SIGUSR1
    os.kill(pid, signum)
    # wait in parent process to give subprocess some time to print
    time.sleep(5)
|
||||
|
||||
|
||||
# Stores the first encountered exception in .exception.
|
||||
# Inspired by https://stackoverflow.com/a/33599967
|
||||
class ErrorTrackingProcess(mp.Process):
|
||||
@ -215,11 +242,7 @@ class ErrorTrackingProcess(mp.Process):
|
||||
self.disable_stderr = disable_stderr
|
||||
|
||||
def run(self):
|
||||
if HAS_FAULTHANDLER:
|
||||
faulthandler.enable()
|
||||
if not IS_WINDOWS:
|
||||
# windows does not have faulthandler.register
|
||||
faulthandler.register(signal.SIGUSR1, chain=True)
|
||||
set_faulthander_if_available()
|
||||
if self.disable_stderr:
|
||||
# Disable polluting stderr with errors that are supposed to happen.
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
@ -233,20 +256,11 @@ class ErrorTrackingProcess(mp.Process):
|
||||
def print_traces_of_all_threads(self):
|
||||
assert self.is_alive(), "can only use print_traces_of_all_threads if the process is alive"
|
||||
assert not self.disable_stderr, "do not disable stderr if you use print_traces_of_all_threads"
|
||||
if HAS_FAULTHANDLER:
|
||||
if not IS_WINDOWS:
|
||||
# use the custom signal if available
|
||||
os.kill(self.pid, signal.SIGUSR1)
|
||||
else:
|
||||
# otherwise we can still use the handler given by faulthandler.enable()
|
||||
# at the cost of killing the process, so let's poll the exception first
|
||||
_ = self.exception
|
||||
os.kill(self.pid, signal.SIGSEGV)
|
||||
else:
|
||||
# if there is no faulthandler, use SIGINT otherwise and hope for the best
|
||||
os.kill(self.pid, signal.SIGINT)
|
||||
# wait in parent process to give subprocess some time to print
|
||||
time.sleep(5)
|
||||
# On platforms without `SIGUSR1`, `set_faulthander_if_available` sets
|
||||
# `faulthandler.enable()`, and `print_traces_of_all_threads` may kill
|
||||
# the process. So let's poll the exception first
|
||||
_ = self.exception
|
||||
print_traces_of_all_threads(self.pid)
|
||||
|
||||
@property
|
||||
def exception(self):
|
||||
@ -403,7 +417,8 @@ def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
|
||||
ds = TestProperExitDataset(12, worker_error_event)
|
||||
|
||||
loader = DataLoader(ds, batch_size=1, shuffle=False,
|
||||
num_workers=num_workers, pin_memory=pin_memory)
|
||||
num_workers=num_workers, pin_memory=pin_memory,
|
||||
worker_init_fn=set_faulthander_if_available)
|
||||
error_it = 2
|
||||
|
||||
if use_workers:
|
||||
@ -832,48 +847,86 @@ class TestDataLoader(TestCase):
|
||||
tester_setup_event),
|
||||
disable_stderr=False)
|
||||
loader_p.start()
|
||||
loader_psutil_p = psutil.Process(loader_p.pid)
|
||||
|
||||
# Wait for loader process to set everything up, e.g., starting
|
||||
# workers.
|
||||
loader_setup_event.wait(timeout=JOIN_TIMEOUT)
|
||||
if not loader_setup_event.is_set():
|
||||
loader_p.print_traces_of_all_threads()
|
||||
fail_msg = desc + ': loader process failed to setup within given time'
|
||||
if loader_p.exception is not None:
|
||||
self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception))
|
||||
fail_msg += ', and had exception {}'.format(loader_p.exception)
|
||||
elif not loader_p.is_alive():
|
||||
self.fail(fail_msg + ', and exited with code {} but had no exception'.format(loader_p.exitcode))
|
||||
fail_msg += ', and exited with code {} but had no exception'.format(loader_p.exitcode)
|
||||
else:
|
||||
self.fail(fail_msg + ', and is still alive.')
|
||||
fail_msg += ', and is still alive.'
|
||||
if loader_p.is_alive():
|
||||
# this may kill the process, needs to run after the above lines
|
||||
loader_p.print_traces_of_all_threads()
|
||||
self.fail(fail_msg)
|
||||
|
||||
worker_psutil_p = psutil.Process(loader_p.pid).children()
|
||||
# We are certain that the workers have started now.
|
||||
worker_psutil_ps = loader_psutil_p.children()
|
||||
|
||||
def fail(reason):
    """Fail the test with a report on the loader and worker processes.

    The message starts with `desc` (followed by `: <reason>` unless `reason`
    is None), then describes the loader process and, when `use_workers` is
    set, each worker process. A process that is still running is described
    by selected psutil attributes and is then asked to print the traces of
    all its threads; one that is not running is reported as exited.

    Args:
        reason: short explanation appended to `desc`, or None.
    """
    psutil_fields = ['pid', 'name', 'cpu_times', 'io_counters',
                     'memory_full_info', 'num_ctx_switches',
                     'open_files', 'threads', 'status',
                     'nice', 'ionice']

    pieces = [desc if reason is None else '{}: {}'.format(desc, reason)]

    pieces.append('\nLoader info:\n\t')
    if not loader_psutil_p.is_running():
        pieces.append('exited with code {}'.format(loader_p.exitcode))
    else:
        pieces.append(str(loader_psutil_p.as_dict(attrs=psutil_fields)))
        # this may kill the process, needs to run after the above line
        loader_p.print_traces_of_all_threads()

    if use_workers:
        pieces.append('\nWorker(s) info:')
        for worker_idx, worker_proc in enumerate(worker_psutil_ps):
            pieces.append('\n\tWorker {}:\n\t\t'.format(worker_idx))
            if not worker_proc.is_running():
                pieces.append('exited with unknown code')
                continue
            pieces.append(str(worker_proc.as_dict(attrs=psutil_fields)))
            # this may kill the process, needs to run after the above line
            print_traces_of_all_threads(worker_proc.pid)

    self.fail(''.join(pieces))
|
||||
|
||||
tester_setup_event.set()
|
||||
|
||||
try:
|
||||
loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
|
||||
if loader_p.is_alive():
|
||||
loader_p.print_traces_of_all_threads()
|
||||
fail_msg = desc + ': loader process did not terminate'
|
||||
fail_reason = 'loader process did not terminate'
|
||||
if loader_p.exception is not None:
|
||||
self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception))
|
||||
fail(fail_reason + ', and had exception {}'.format(loader_p.exception))
|
||||
else:
|
||||
self.fail(fail_msg + ', and had no exception')
|
||||
_, alive = psutil.wait_procs(worker_psutil_p, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
|
||||
fail(fail_reason + ', and had no exception')
|
||||
_, alive = psutil.wait_procs(worker_psutil_ps, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
|
||||
if len(alive) > 0:
|
||||
self.fail(desc + ': worker process (pid(s) {}) did not terminate'.format(
|
||||
', '.join(str(p.pid) for p in alive)))
|
||||
self.fail(get_fail_msg('worker process (pid(s) {}) did not terminate'.format(
|
||||
', '.join(str(p.pid) for p in alive))))
|
||||
if exit_method is None:
|
||||
self.assertEqual(loader_p.exitcode, 0)
|
||||
if loader_p.exitcode != 0:
|
||||
fail('loader process had nonzero exitcode {}'.format(loader_p.exitcode))
|
||||
else:
|
||||
self.assertNotEqual(loader_p.exitcode, 0)
|
||||
if loader_p.exitcode == 0:
|
||||
fail('loader process had zero exitcode')
|
||||
if exit_method == 'loader_error':
|
||||
self.assertIsInstance(loader_p.exception, RuntimeError, desc)
|
||||
self.assertIn('Loader error', str(loader_p.exception), desc)
|
||||
if not isinstance(loader_p.exception, RuntimeError) or \
|
||||
'Loader error' not in str(loader_p.exception):
|
||||
fail('loader process did not raise expected exception, but had {}'.format(
|
||||
loader_p.exception))
|
||||
elif exit_method == 'worker_kill':
|
||||
if isinstance(loader_p.exception, RuntimeError):
|
||||
self.assertIn('DataLoader worker (pid', str(loader_p.exception), desc)
|
||||
elif isinstance(loader_p.exception, ConnectionRefusedError):
|
||||
if 'DataLoader worker (pid' not in str(loader_p.exception):
|
||||
fail('loader process did not raise expected exception, but had {}'.format(
|
||||
loader_p.exception))
|
||||
elif PY3 and isinstance(loader_p.exception, ConnectionRefusedError):
|
||||
# Sometimes, when the worker is being killed and is freeing its
|
||||
# resources, the unpickling in loader process will be met with
|
||||
# a `ConnectionRefusedError` as it can not open a socket to receive
|
||||
@ -882,11 +935,20 @@ class TestDataLoader(TestCase):
|
||||
# handler. So we permit this as an allowed error as well.
|
||||
# After all, we are happy as long as it terminates.
|
||||
pass
|
||||
elif not Py3 and isinstance(loader_p.exception, OSError):
|
||||
# Same reasoning as the above if-block for Py2,
|
||||
# where ConnectionRefusedError isn't a thing.
|
||||
if loader_p.exception.errno != errno.ECONNREFUSED:
|
||||
fail('loader process did not raise expected exception, but had {}'.format(
|
||||
loader_p.exception))
|
||||
else:
|
||||
self.fail(desc)
|
||||
fail('loader process did not raise expected exception, but had {}'.format(
|
||||
loader_p.exception))
|
||||
elif exit_method == 'worker_error':
|
||||
self.assertIsInstance(loader_p.exception, RuntimeError, desc)
|
||||
self.assertIn('Worker error', str(loader_p.exception), desc)
|
||||
if not isinstance(loader_p.exception, RuntimeError) or \
|
||||
'Worker error' not in str(loader_p.exception):
|
||||
fail('loader process did not raise expected exception, but had {}'.format(
|
||||
loader_p.exception))
|
||||
finally:
|
||||
loader_p.terminate()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user