mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/119667 Approved by: https://github.com/Skylion007
262 lines
9.0 KiB
C++
262 lines
9.0 KiB
C++
#include <torch/csrc/DataLoader.h>
|
|
|
|
// Together with `torch/utils/data/_utils/signal_handling.py`, the following
|
|
// is an effort to do our best to provide some error message to users when a
|
|
// worker dies due to error / critical signals.
|
|
//
|
|
// See NOTE [ Signal handling in multiprocessing data loading ] for more
|
|
// details.
|
|
|
|
// TODO: The following don't work on Windows. Specifically, sigaction, waitid
|
|
// calls, and SIGCHLD handler. Currently, dummy implementations are provided
|
|
// for Windows.
|
|
|
|
#ifndef _WIN32
|
|
|
|
#include <torch/csrc/Exceptions.h>
|
|
#include <torch/csrc/utils/python_numbers.h>
|
|
|
|
#include <c10/util/irange.h>
|
|
#include <fmt/format.h>
|
|
|
|
#include <sys/wait.h>
#include <csignal>
#include <map>
#include <set>
#include <sstream>
#include <utility>
|
|
|
|
using namespace torch;
|
|
|
|
// Critical signal handlers should be registered in worker processes before
// they start doing work (Python handle: _set_worker_signal_handlers()).
//
// Each handler generated by this macro:
//   1. writes ERROR_MSG to stderr,
//   2. resets SIGNAL's disposition to SIG_DFL, and
//   3. re-raises SIGNAL, so the worker terminates "by signal" and the main
//      process can retrieve that kill information.
// If the disposition cannot be reset, the worker exits with EXIT_FAILURE.
//
// Only async-signal-safe calls (write, sigemptyset, sigaction, raise, _exit)
// are made inside the handler.
//
// ERROR_MSG must be a string literal: its length is taken with sizeof, minus
// one so that the literal's trailing NUL terminator is not written to stderr.
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)               \
  static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) {     \
    (void)sig;                                                        \
    (void)info;                                                       \
    (void)ctx;                                                        \
    auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) - 1); \
    (void)_w;                                                         \
    struct sigaction sa {};                                           \
    sa.sa_handler = SIG_DFL;                                          \
    sa.sa_flags = 0;                                                  \
    if (sigemptyset(&sa.sa_mask) != 0 ||                              \
        sigaction(SIGNAL, &sa, nullptr) != 0) {                       \
      _exit(EXIT_FAILURE);                                            \
    } else {                                                          \
      raise(SIGNAL);                                                  \
    }                                                                 \
  }
|
|
|
|
// Installs `handler` as a three-argument (SA_SIGINFO) handler for `signal`.
// signal(2) is not portable, so sigaction(2) is used instead:
// http://man7.org/linux/man-pages/man2/signal.2.html
//
// The handler is installed with an empty mask and the flags
//   SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER.
// If `old_sa_ptr` is non-null, sigaction stores the previous action there.
//
// Throws std::runtime_error naming the signal if the mask cannot be emptied
// or sigaction fails (e.g. for a signal that cannot be caught).
static inline void setSignalHandler(
    int signal,
    void (*handler)(int, siginfo_t*, void*),
    struct sigaction* old_sa_ptr) {
  struct sigaction action {};
  action.sa_sigaction = handler;
  action.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;

  // sigaction is only attempted once the mask was successfully emptied.
  const bool installed = sigemptyset(&action.sa_mask) == 0 &&
      sigaction(signal, &action, old_sa_ptr) == 0;
  if (installed) {
    return;
  }

  std::ostringstream message;
  message << "An error occurred while setting handler for "
          << strsignal(signal) << ".";
  throw std::runtime_error(message.str());
}
|
|
|
|
// Handlers for the critical signals a worker can receive. Each one writes its
// message to stderr, then restores the default disposition and re-raises the
// signal so the main process sees the worker as killed by that signal.

// Bus error; the message points at insufficient shared memory as a likely cause.
SIGNAL_HANDLER(
    SIGBUS,
    handler_SIGBUS,
    "ERROR: Unexpected bus error encountered in worker. "
    "This might be caused by insufficient shared memory (shm).\n");
// Segmentation fault.
SIGNAL_HANDLER(
    SIGSEGV,
    handler_SIGSEGV,
    "ERROR: Unexpected segmentation fault encountered in worker.\n");
// Floating-point exception.
SIGNAL_HANDLER(
    SIGFPE,
    handler_SIGFPE,
    "ERROR: Unexpected floating-point exception encountered in worker.\n");
|
|
|
|
// SIGTERM handler for DataLoader worker processes.
//
// When an error happens in DataLoader methods and Python starts to exit, the
// error trace keeps the loader alive, and Python may kill the child processes
// before the loader object is deleted. In that case the cleanup in
// DataLoader.__del__ has not run yet, and the SIGCHLD handler would print an
// error saying a worker was killed by SIGTERM. To avoid this, a SIGTERM sent
// by the parent (the main loader process) is suppressed here with
// _exit(EXIT_SUCCESS). Exiting with a nonzero code would make the loader's
// SIGCHLD handler report a RuntimeError again, defeating the whole purpose.
//
// A SIGTERM from any other sender restores the default disposition and
// re-raises the signal; if the disposition cannot be restored, the worker
// exits with EXIT_FAILURE.
static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) {
  (void)sig;
  (void)ctx;

  const bool sent_by_parent = info->si_pid == getppid();
  if (sent_by_parent) {
    _exit(EXIT_SUCCESS);
  }

  struct sigaction default_action {};
  default_action.sa_handler = SIG_DFL;
  default_action.sa_flags = 0;
  // sigaction is only attempted once the mask was successfully emptied.
  const bool restored = sigemptyset(&default_action.sa_mask) == 0 &&
      sigaction(SIGTERM, &default_action, nullptr) == 0;
  if (!restored) {
    _exit(EXIT_FAILURE);
  }
  raise(SIGTERM);
}
|
|
|
|
// Extension hook invoked at the end of _set_worker_signal_handlers(). This
// default definition does nothing; because it is declared weak, a strong
// definition of the same function linked into the binary takes precedence.
__attribute__((weak)) void setDataLoaderSignalHandlers() {}
|
|
|
|
// Python: _set_worker_signal_handlers()
// Installs this file's SIGBUS, SIGSEGV, SIGTERM and SIGFPE handlers (in that
// order, discarding the previous actions), then calls the weakly-defined
// setDataLoaderSignalHandlers() hook. Returns None. A failure to install any
// handler throws from setSignalHandler and is surfaced via HANDLE_TH_ERRORS;
// handlers later in the list are then not installed.
static PyObject* THPModule_setWorkerSignalHandlers(
    PyObject* module,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  struct SignalInstall {
    int signo;
    void (*handler)(int, siginfo_t*, void*);
  };
  const SignalInstall installs[] = {
      {SIGBUS, &handler_SIGBUS},
      {SIGSEGV, &handler_SIGSEGV},
      {SIGTERM, &handler_SIGTERM},
      {SIGFPE, &handler_SIGFPE},
  };
  for (const auto& install : installs) {
    setSignalHandler(install.signo, install.handler, nullptr);
  }
  setDataLoaderSignalHandlers();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Registry of worker processes, keyed by the _BaseDataLoaderIter id passed to
// _set_worker_pids. Each value is the set of worker pids registered for that
// loader. Entries are added by _set_worker_pids, removed by
// _remove_worker_pids, and scanned (and possibly cleared) by
// _error_if_any_worker_fails. No locking is done here.
// NOTE(review): presumably all access happens with the GIL held; confirm.
static std::map<int64_t, std::set<pid_t>> worker_pids = {};
|
|
|
|
// Builds the user-facing error message for a worker whose waitid() status
// reports a failure: a nonzero exit code (CLD_EXITED), or termination by a
// signal (CLD_KILLED / CLD_DUMPED). Returns an empty string for any other
// status, including a clean exit with EXIT_SUCCESS.
static std::string describeWorkerFailure(
    pid_t worker_pid,
    const siginfo_t& info) {
  std::ostringstream oss;
  if (info.si_code == CLD_EXITED) {
    if (info.si_status == EXIT_SUCCESS) {
      return std::string();
    }
    oss << "DataLoader worker (pid " << worker_pid << ") exited "
        << "unexpectedly with exit code " << info.si_status << ". "
        << "Details are lost due to multiprocessing. Rerunning with "
        << "num_workers=0 may give better error trace.";
    return oss.str();
  }
  if (info.si_code != CLD_KILLED && info.si_code != CLD_DUMPED) {
    return std::string();
  }
  oss << "DataLoader worker (pid " << worker_pid << ") is killed "
      << "by signal: " << strsignal(info.si_status) << ". ";
  if (info.si_status == SIGBUS) {
    oss << "It is possible that dataloader's workers are out of shared memory. "
        << "Please try to raise your shared memory limit.";
  }
  return oss.str();
}

// Python: _error_if_any_worker_fails()
// Polls every registered worker pid without blocking and without reaping it.
// On the first worker found to have failed, clears that loader's pid set and
// throws std::runtime_error (surfaced via HANDLE_TH_ERRORS). Returns None
// otherwise.
static PyObject* THPModule_errorIfAnyWorkerFails(
    PyObject* module,
    PyObject* noargs) {
  HANDLE_TH_ERRORS

  // Only check the pids we care about.
  for (auto& entry : worker_pids) {
    std::set<pid_t>& pids = entry.second;
    for (const pid_t pid : pids) {
      // waitid (rather than waitpid) allows WNOWAIT: the child is left in a
      // waitable state so Python and other handlers can still get whatever
      // information they want about it.
      siginfo_t info{};
      info.si_pid = 0;
      const int rc = waitid(P_PID, pid, &info, WEXITED | WNOHANG | WNOWAIT);
      // Ignore errors, and the WNOHANG case where there is no waitable child.
      if (rc < 0 || info.si_pid == 0) {
        continue;
      }
      const std::string message = describeWorkerFailure(pid, info);
      if (message.empty()) {
        continue;
      }
      // Clearing is necessary. Otherwise, the runtime error will kill the
      // other workers and trigger this check again.
      pids.clear();
      throw std::runtime_error(message);
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
|
|
// of pids we are interested in.
|
|
static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* args) {
|
|
HANDLE_TH_ERRORS
|
|
TORCH_CHECK_TYPE(
|
|
PyTuple_GET_SIZE(args) == 2,
|
|
"_set_worker_pids expects exactly 2 arguments.");
|
|
int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
|
|
TORCH_CHECK_VALUE(
|
|
worker_pids.find(key) == worker_pids.end(),
|
|
"_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
|
|
PyObject* child_pids = PyTuple_GET_ITEM(args, 1);
|
|
TORCH_CHECK_TYPE(
|
|
PyTuple_Check(child_pids),
|
|
"_set_worker_pids expects a tuple for child_pids, but got ",
|
|
Py_TYPE(child_pids)->tp_name,
|
|
".");
|
|
std::set<pid_t> pids_set = {};
|
|
auto size = PyTuple_GET_SIZE(child_pids);
|
|
for (const auto idx : c10::irange(size)) {
|
|
PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
|
|
pids_set.insert(static_cast<pid_t>(THPUtils_unpackLong(obj)));
|
|
}
|
|
|
|
worker_pids[key] = pids_set;
|
|
|
|
Py_RETURN_NONE;
|
|
END_HANDLE_TH_ERRORS
|
|
}
|
|
|
|
// Python: _remove_worker_pids(loader_id)
// Unregisters the worker pids recorded for `loader_id` by _set_worker_pids.
// Fails via TORCH_CHECK_VALUE if no entry exists for that id. Returns None.
static PyObject* THPModule_removeWorkerPIDs(
    PyObject* module,
    PyObject* loader_id) {
  HANDLE_TH_ERRORS
  const int64_t key = THPUtils_unpackLong(loader_id);
  // erase(key) reports how many entries were removed; 0 means the loader was
  // never registered (or was already removed), and nothing was changed.
  const auto removed = worker_pids.erase(key);
  TORCH_CHECK_VALUE(
      removed != 0,
      "Cannot find worker information for _BaseDataLoaderIter with id ",
      key);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
#undef SIGNAL_HANDLER
|
|
|
|
#else
|
|
// Dummy implementations for Windows, where the sigaction / waitid / SIGCHLD
// based logic above is not available. Every entry point simply returns None.

// No-op on Windows; returns None.
static PyObject* THPModule_setWorkerSignalHandlers(
    PyObject* module,
    PyObject* _ignored) {
  Py_RETURN_NONE;
}
|
|
|
|
// No-op on Windows: worker pids are not tracked; returns None.
static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* _ignored) {
  Py_RETURN_NONE;
}
|
|
|
|
// No-op on Windows: there is no registry to remove from; returns None.
static PyObject* THPModule_removeWorkerPIDs(
    PyObject* module,
    PyObject* _ignored) {
  Py_RETURN_NONE;
}
|
|
|
|
// No-op on Windows: workers are never checked here, so this never raises;
// returns None.
static PyObject* THPModule_errorIfAnyWorkerFails(
    PyObject* module,
    PyObject* _ignored) {
  Py_RETURN_NONE;
}
|
|
|
|
#endif
|
|
|
|
// Python-level methods exported by this file. _set_worker_pids receives its
// arguments as a tuple (METH_VARARGS), _remove_worker_pids receives a single
// loader id (METH_O), and the other two take no arguments (METH_NOARGS). The
// array is terminated by an all-null sentinel entry.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
PyMethodDef DataLoaderMethods[] = {
    {"_set_worker_signal_handlers",
     THPModule_setWorkerSignalHandlers,
     METH_NOARGS,
     nullptr},
    {"_set_worker_pids", THPModule_setWorkerPIDs, METH_VARARGS, nullptr},
    {"_remove_worker_pids", THPModule_removeWorkerPIDs, METH_O, nullptr},
    {"_error_if_any_worker_fails",
     THPModule_errorIfAnyWorkerFails,
     METH_NOARGS,
     nullptr},
    {nullptr, nullptr, 0, nullptr}};
|