diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py
index 15e10a9c22ff..fed869c9ae26 100644
--- a/torch/multiprocessing/spawn.py
+++ b/torch/multiprocessing/spawn.py
@@ -1,8 +1,11 @@
 import logging
 import multiprocessing
 import multiprocessing.connection
+import os
+import pickle
 import signal
 import sys
+import tempfile
 import time
 import warnings
 from typing import Optional
@@ -61,7 +64,7 @@ class ProcessExitedException(ProcessException):
         )
 
 
-def _wrap(fn, i, args, error_queue):
+def _wrap(fn, i, args, error_file):
     # prctl(2) is a Linux specific system call.
     # On other systems the following function call has no effect.
     # This is set to ensure that non-daemonic child processes can
@@ -76,13 +79,14 @@ def _wrap(fn, i, args, error_queue):
         # Propagate exception to parent process, keeping original traceback
         import traceback
 
-        error_queue.put(traceback.format_exc())
+        with open(error_file, "wb") as fh:
+            pickle.dump(traceback.format_exc(), fh)
         sys.exit(1)
 
 
 class ProcessContext:
-    def __init__(self, processes, error_queues):
-        self.error_queues = error_queues
+    def __init__(self, processes, error_files):
+        self.error_files = error_files
         self.processes = processes
         self.sentinels = {
             process.sentinel: index for index, process in enumerate(processes)
@@ -153,9 +157,9 @@ class ProcessContext:
                 process.kill()
             process.join()
 
-        # There won't be an error on the queue if the process crashed.
+        # The file will only be created if the process crashed.
         failed_process = self.processes[error_index]
-        if self.error_queues[error_index].empty():
+        if not os.access(self.error_files[error_index], os.R_OK):
             exitcode = self.processes[error_index].exitcode
             if exitcode < 0:
                 try:
@@ -177,16 +181,17 @@ class ProcessContext:
                     exit_code=exitcode,
                 )
 
-        original_trace = self.error_queues[error_index].get()
+        with open(self.error_files[error_index], "rb") as fh:
+            original_trace = pickle.load(fh)
         msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
         msg += original_trace
         raise ProcessRaisedException(msg, error_index, failed_process.pid)
 
 
 class SpawnContext(ProcessContext):
-    def __init__(self, processes, error_queues):
+    def __init__(self, processes, error_files):
         warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.")
-        super().__init__(processes, error_queues)
+        super().__init__(processes, error_files)
 
 
 # Note: [start_processes]
@@ -201,20 +206,30 @@ def start_processes(
     fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"
 ):
     mp = multiprocessing.get_context(start_method)
-    error_queues = []
+    error_files = []
     processes = []
     for i in range(nprocs):
-        error_queue = mp.SimpleQueue()
+        # Each process is assigned a file to write tracebacks to. We
+        # use the file being non-empty to indicate an exception
+        # occurred (vs an expected shutdown). Note: this previously
+        # used a multiprocessing.Queue but that can be prone to
+        # deadlocks, so we went with a simpler solution for a one-shot
+        # message between processes.
+        tf = tempfile.NamedTemporaryFile(
+            prefix="pytorch-errorfile-", suffix=".pickle", delete=False
+        )
+        tf.close()
+        os.unlink(tf.name)
         process = mp.Process(
             target=_wrap,
-            args=(fn, i, args, error_queue),
+            args=(fn, i, args, tf.name),
             daemon=daemon,
         )
         process.start()
-        error_queues.append(error_queue)
+        error_files.append(tf.name)
         processes.append(process)
 
-    context = ProcessContext(processes, error_queues)
+    context = ProcessContext(processes, error_files)
     if not join:
        return context
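
For readers outside the PyTorch tree, here is a minimal, self-contained sketch of the error-file handshake the patch introduces. The `worker` function and `run_workers` driver are hypothetical names for illustration; only the tempfile/pickle/`os.access` mechanics mirror the change above.

```python
# Sketch of the error-file pattern from the patch above (assumed names:
# `worker`, `run_workers`; the mechanics follow the diff, not its exact API).
import multiprocessing
import os
import pickle
import sys
import tempfile
import traceback


def _wrap(fn, i, args, error_file):
    try:
        fn(i, *args)
    except Exception:
        # One-shot message to the parent: dump the formatted traceback.
        # The file's existence doubles as the "an exception occurred" flag.
        with open(error_file, "wb") as fh:
            pickle.dump(traceback.format_exc(), fh)
        sys.exit(1)


def worker(i):
    if i == 1:
        raise RuntimeError("boom in worker %d" % i)


def run_workers(fn, nprocs=2):
    mp = multiprocessing.get_context("spawn")
    processes, error_files = [], []
    for i in range(nprocs):
        # Reserve a unique path, then unlink it so the file only exists
        # again if the child actually writes a traceback into it.
        tf = tempfile.NamedTemporaryFile(
            prefix="errorfile-", suffix=".pickle", delete=False
        )
        tf.close()
        os.unlink(tf.name)
        p = mp.Process(target=_wrap, args=(fn, i, (), tf.name))
        p.start()
        processes.append(p)
        error_files.append(tf.name)

    for i, p in enumerate(processes):
        p.join()
        if p.exitcode != 0:
            if os.access(error_files[i], os.R_OK):
                with open(error_files[i], "rb") as fh:
                    print("process %d raised:\n%s" % (i, pickle.load(fh)))
                os.unlink(error_files[i])
            else:
                # No file: the child crashed (e.g. killed by a signal)
                # before it could record a traceback.
                print("process %d died with exit code %d" % (i, p.exitcode))


if __name__ == "__main__":
    run_workers(worker)
```

The design choice mirrors the comment in the patch: a queue read can block indefinitely if a child dies at the wrong moment, whereas the file is a one-shot message the parent inspects only after the child has exited, with a non-blocking existence check (`os.access`) followed by a plain read.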