DataLoader: add error detection for worker_init_fn (#20150)

Summary:
This is an attempt to isolate unrelated changes from #19228 for easier review.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20150

Differential Revision: D15314891

Pulled By: ezyang

fbshipit-source-id: 8c429747ba83ad5aca4cdd8f8086bcf65a326921
This commit is contained in:
Tongzhou Wang
2019-05-12 18:26:01 -07:00
committed by Facebook Github Bot
parent 163f0e182c
commit f496ea36b2
2 changed files with 20 additions and 2 deletions

View File

@ -454,6 +454,10 @@ def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
def init_fn(worker_id):
    """Seed torch's default random number generator with the fixed value 12345.

    Args:
        worker_id: Accepted but not used; the same seed is applied for every value.
    """
    torch.manual_seed(12345)
# used with test_error_in_init
def error_worker_init_fn(_):
    """Unconditionally raise ``RuntimeError("Error in worker_init_fn")``.

    The single positional argument is ignored. ``test_error_in_init`` passes
    this function as ``worker_init_fn`` and matches on the message text, so the
    message must stay exactly as written.

    Raises:
        RuntimeError: Always.
    """
    raise RuntimeError("Error in worker_init_fn")
class TestDataLoader(TestCase):
@ -509,6 +513,11 @@ class TestDataLoader(TestCase):
self.assertRaises(ValueError, fn)
def test_error_in_init(self):
    """Check that an exception raised by ``worker_init_fn`` reaches the caller.

    Builds a two-worker DataLoader whose ``worker_init_fn`` always raises, then
    asserts that consuming the loader raises a ``RuntimeError`` whose message
    matches the one raised by ``error_worker_init_fn``.
    """
    loader = DataLoader(self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn)
    with self.assertRaisesRegex(RuntimeError, 'Error in worker_init_fn'):
        # Fully drain the iterator; the assertion needs the error to be raised
        # somewhere during iteration.
        list(iter(loader))
def test_sequential(self):
    """Run the shared ``_test_sequential`` check on a DataLoader built with default arguments."""
    self._test_sequential(DataLoader(self.dataset))

View File

@ -75,8 +75,13 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
data_queue.cancel_join_thread()
init_exception = None
if init_fn is not None:
init_fn(worker_id)
try:
init_fn(worker_id)
except Exception:
init_exception = ExceptionWrapper(sys.exc_info())
watchdog = ManagerWatchdog()
@ -96,7 +101,11 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
continue
idx, batch_indices = r
try:
samples = collate_fn([dataset[i] for i in batch_indices])
if init_exception is not None:
samples = init_exception
init_exception = None
else:
samples = collate_fn([dataset[i] for i in batch_indices])
except Exception:
# It is important that we don't store exc_info in a variable,
# see NOTE [ Python Traceback Reference Cycle Problem ]