mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Robustify torch.multiprocessing.spawn error reporting to be less deadlock prone (#114688)
multiprocessing.Queue relies on, among other things, background threads to send messages between processes. This works in the happy path but can cause issues if a process is exiting by bypassing atexit handlers or crashing, because the writer to the Queue can terminate while the reader is blocked reading the queue. The reader sees the queue as non-empty, yet even with a timeout it will actually block forever. An example of a Queue deadlock is here: https://gist.github.com/chipturner/342f72341f087737befe9df84d0e41ce Since the error reporting case here is a simple one-shot message from the dying child to the parent, we can just use a file-based rendezvous. This eliminates the deadlock when a large traceback is still being flushed to the network when a child exits. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114688 Approved by: https://github.com/suo, https://github.com/yifuwang
This commit is contained in:
committed by
PyTorch MergeBot
parent
2962271f58
commit
2ed47fecc5
@ -1,8 +1,11 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import multiprocessing.connection
|
||||
import os
|
||||
import pickle
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import warnings
|
||||
from typing import Optional
|
||||
@ -61,7 +64,7 @@ class ProcessExitedException(ProcessException):
|
||||
)
|
||||
|
||||
|
||||
def _wrap(fn, i, args, error_queue):
|
||||
def _wrap(fn, i, args, error_file):
|
||||
# prctl(2) is a Linux specific system call.
|
||||
# On other systems the following function call has no effect.
|
||||
# This is set to ensure that non-daemonic child processes can
|
||||
@ -76,13 +79,14 @@ def _wrap(fn, i, args, error_queue):
|
||||
# Propagate exception to parent process, keeping original traceback
|
||||
import traceback
|
||||
|
||||
error_queue.put(traceback.format_exc())
|
||||
with open(error_file, "wb") as fh:
|
||||
pickle.dump(traceback.format_exc(), fh)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class ProcessContext:
|
||||
def __init__(self, processes, error_queues):
|
||||
self.error_queues = error_queues
|
||||
def __init__(self, processes, error_files):
|
||||
self.error_files = error_files
|
||||
self.processes = processes
|
||||
self.sentinels = {
|
||||
process.sentinel: index for index, process in enumerate(processes)
|
||||
@ -153,9 +157,9 @@ class ProcessContext:
|
||||
process.kill()
|
||||
process.join()
|
||||
|
||||
# There won't be an error on the queue if the process crashed.
|
||||
# The file will only be created if the process crashed.
|
||||
failed_process = self.processes[error_index]
|
||||
if self.error_queues[error_index].empty():
|
||||
if not os.access(self.error_files[error_index], os.R_OK):
|
||||
exitcode = self.processes[error_index].exitcode
|
||||
if exitcode < 0:
|
||||
try:
|
||||
@ -177,16 +181,17 @@ class ProcessContext:
|
||||
exit_code=exitcode,
|
||||
)
|
||||
|
||||
original_trace = self.error_queues[error_index].get()
|
||||
with open(self.error_files[error_index], "rb") as fh:
|
||||
original_trace = pickle.load(fh)
|
||||
msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
|
||||
msg += original_trace
|
||||
raise ProcessRaisedException(msg, error_index, failed_process.pid)
|
||||
|
||||
|
||||
class SpawnContext(ProcessContext):
|
||||
def __init__(self, processes, error_queues):
|
||||
def __init__(self, processes, error_files):
|
||||
warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.")
|
||||
super().__init__(processes, error_queues)
|
||||
super().__init__(processes, error_files)
|
||||
|
||||
|
||||
# Note: [start_processes]
|
||||
@ -201,20 +206,30 @@ def start_processes(
|
||||
fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"
|
||||
):
|
||||
mp = multiprocessing.get_context(start_method)
|
||||
error_queues = []
|
||||
error_files = []
|
||||
processes = []
|
||||
for i in range(nprocs):
|
||||
error_queue = mp.SimpleQueue()
|
||||
# Each process is assigned a file to write tracebacks to. We
|
||||
# use the file being non-empty to indicate an exception
|
||||
# occurred (vs an expected shutdown). Note: this previously
|
||||
# used a multiprocessing.Queue but that can be prone to
|
||||
# deadlocks, so we went with a simpler solution for a one-shot
|
||||
# message between processes.
|
||||
tf = tempfile.NamedTemporaryFile(
|
||||
prefix="pytorch-errorfile-", suffix=".pickle", delete=False
|
||||
)
|
||||
tf.close()
|
||||
os.unlink(tf.name)
|
||||
process = mp.Process(
|
||||
target=_wrap,
|
||||
args=(fn, i, args, error_queue),
|
||||
args=(fn, i, args, tf.name),
|
||||
daemon=daemon,
|
||||
)
|
||||
process.start()
|
||||
error_queues.append(error_queue)
|
||||
error_files.append(tf.name)
|
||||
processes.append(process)
|
||||
|
||||
context = ProcessContext(processes, error_queues)
|
||||
context = ProcessContext(processes, error_files)
|
||||
if not join:
|
||||
return context
|
||||
|
||||
|
Reference in New Issue
Block a user