mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
DataLoader: add error detection for worker_init_fn (#20150)
Summary: This is an attempt to isolate unrelated changes from #19228 for easier review. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20150 Differential Revision: D15314891 Pulled By: ezyang fbshipit-source-id: 8c429747ba83ad5aca4cdd8f8086bcf65a326921
This commit is contained in:
committed by
Facebook Github Bot
parent
163f0e182c
commit
f496ea36b2
@ -454,6 +454,10 @@ def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
|
||||
def init_fn(worker_id):
|
||||
torch.manual_seed(12345)
|
||||
|
||||
# used with test_error_in_init
|
||||
def error_worker_init_fn(_):
|
||||
raise RuntimeError("Error in worker_init_fn")
|
||||
|
||||
|
||||
class TestDataLoader(TestCase):
|
||||
|
||||
@ -509,6 +513,11 @@ class TestDataLoader(TestCase):
|
||||
|
||||
self.assertRaises(ValueError, fn)
|
||||
|
||||
def test_error_in_init(self):
|
||||
loader = DataLoader(self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn)
|
||||
with self.assertRaisesRegex(RuntimeError, 'Error in worker_init_fn'):
|
||||
list(iter(loader))
|
||||
|
||||
def test_sequential(self):
|
||||
self._test_sequential(DataLoader(self.dataset))
|
||||
|
||||
|
@ -75,8 +75,13 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
|
||||
|
||||
data_queue.cancel_join_thread()
|
||||
|
||||
init_exception = None
|
||||
|
||||
if init_fn is not None:
|
||||
init_fn(worker_id)
|
||||
try:
|
||||
init_fn(worker_id)
|
||||
except Exception:
|
||||
init_exception = ExceptionWrapper(sys.exc_info())
|
||||
|
||||
watchdog = ManagerWatchdog()
|
||||
|
||||
@ -96,7 +101,11 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
|
||||
continue
|
||||
idx, batch_indices = r
|
||||
try:
|
||||
samples = collate_fn([dataset[i] for i in batch_indices])
|
||||
if init_exception is not None:
|
||||
samples = init_exception
|
||||
init_exception = None
|
||||
else:
|
||||
samples = collate_fn([dataset[i] for i in batch_indices])
|
||||
except Exception:
|
||||
# It is important that we don't store exc_info in a variable,
|
||||
# see NOTE [ Python Traceback Reference Cycle Problem ]
|
||||
|
Reference in New Issue
Block a user