mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
DataLoader: add error detection for worker_init_fn (#20150)
Summary: This is an attempt to isolate unrelated changes from #19228 for easier review. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20150 Differential Revision: D15314891 Pulled By: ezyang fbshipit-source-id: 8c429747ba83ad5aca4cdd8f8086bcf65a326921
This commit is contained in:
committed by
Facebook Github Bot
parent
163f0e182c
commit
f496ea36b2
@ -454,6 +454,10 @@ def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
|
|||||||
def init_fn(worker_id):
|
def init_fn(worker_id):
|
||||||
torch.manual_seed(12345)
|
torch.manual_seed(12345)
|
||||||
|
|
||||||
|
# used with test_error_in_init
|
||||||
|
def error_worker_init_fn(_):
|
||||||
|
raise RuntimeError("Error in worker_init_fn")
|
||||||
|
|
||||||
|
|
||||||
class TestDataLoader(TestCase):
|
class TestDataLoader(TestCase):
|
||||||
|
|
||||||
@ -509,6 +513,11 @@ class TestDataLoader(TestCase):
|
|||||||
|
|
||||||
self.assertRaises(ValueError, fn)
|
self.assertRaises(ValueError, fn)
|
||||||
|
|
||||||
|
def test_error_in_init(self):
|
||||||
|
loader = DataLoader(self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn)
|
||||||
|
with self.assertRaisesRegex(RuntimeError, 'Error in worker_init_fn'):
|
||||||
|
list(iter(loader))
|
||||||
|
|
||||||
def test_sequential(self):
|
def test_sequential(self):
|
||||||
self._test_sequential(DataLoader(self.dataset))
|
self._test_sequential(DataLoader(self.dataset))
|
||||||
|
|
||||||
|
@ -75,8 +75,13 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
|
|||||||
|
|
||||||
data_queue.cancel_join_thread()
|
data_queue.cancel_join_thread()
|
||||||
|
|
||||||
|
init_exception = None
|
||||||
|
|
||||||
if init_fn is not None:
|
if init_fn is not None:
|
||||||
init_fn(worker_id)
|
try:
|
||||||
|
init_fn(worker_id)
|
||||||
|
except Exception:
|
||||||
|
init_exception = ExceptionWrapper(sys.exc_info())
|
||||||
|
|
||||||
watchdog = ManagerWatchdog()
|
watchdog = ManagerWatchdog()
|
||||||
|
|
||||||
@ -96,7 +101,11 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
|
|||||||
continue
|
continue
|
||||||
idx, batch_indices = r
|
idx, batch_indices = r
|
||||||
try:
|
try:
|
||||||
samples = collate_fn([dataset[i] for i in batch_indices])
|
if init_exception is not None:
|
||||||
|
samples = init_exception
|
||||||
|
init_exception = None
|
||||||
|
else:
|
||||||
|
samples = collate_fn([dataset[i] for i in batch_indices])
|
||||||
except Exception:
|
except Exception:
|
||||||
# It is important that we don't store exc_info in a variable,
|
# It is important that we don't store exc_info in a variable,
|
||||||
# see NOTE [ Python Traceback Reference Cycle Problem ]
|
# see NOTE [ Python Traceback Reference Cycle Problem ]
|
||||||
|
Reference in New Issue
Block a user