mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Let's have some fun. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78828 Approved by: https://github.com/ezyang
60 lines
1.4 KiB
C++
60 lines
1.4 KiB
C++
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>

#include <cerrno>
#include <stdexcept>
#include <system_error>

#if defined(__linux__)
#include <sys/prctl.h>
#endif
|
|
|
|
// SYSASSERT(rv, ...): if the system-call result `rv` is negative, throws a
// std::system_error built from the current errno. Any extra arguments
// (e.g. a short message such as "prctl") are forwarded to the
// std::system_error constructor. `rv` is evaluated exactly once.
//
// errno is captured before the message arguments are evaluated, because
// those arguments could themselves clobber errno. The body is wrapped in
// do { ... } while (0) so the macro acts as a single statement, and
// `if (cond) SYSASSERT(rv, "x"); else ...` compiles with the expected
// `else` binding.
#define SYSASSERT(rv, ...)                                        \
  do {                                                            \
    if ((rv) < 0) {                                               \
      const int sysassert_saved_errno_ = errno;                   \
      throw std::system_error(                                    \
          sysassert_saved_errno_, std::system_category(), ##__VA_ARGS__); \
    }                                                             \
  } while (0)
|
|
|
|
namespace torch {
|
|
namespace multiprocessing {
|
|
|
|
namespace {
|
|
|
|
// Imports the Python-side `torch.multiprocessing` module and attaches the
// native helper `_prctl_pr_set_pdeathsig` to it. Registered in the `methods`
// table of this file as `_multiprocessing_init` (METH_NOARGS).
//
// Returns a new reference to Py_True on success. If the import fails, it
// returns nullptr with the Python exception left set by
// PyImport_ImportModule, following the CPython calling convention.
PyObject* multiprocessing_init(PyObject* _unused, PyObject* noargs) {
  auto multiprocessing_module =
      THPObjectPtr(PyImport_ImportModule("torch.multiprocessing"));
  if (!multiprocessing_module) {
    // CPython invokes this function directly (it is a raw PyCFunction), so
    // a C++ exception must not unwind through the interpreter's C frames.
    // The Python error is already set; signal failure by returning nullptr.
    return nullptr;
  }

  // NOTE(review): the pybind11 calls below could in principle also throw C++
  // exceptions; consider wrapping the body in HANDLE_TH_ERRORS /
  // END_HANDLE_TH_ERRORS if that macro is reachable from this file.
  auto module = py::handle(multiprocessing_module).cast<py::module>();

  // _prctl_pr_set_pdeathsig(signal): on Linux, calls
  // prctl(PR_SET_PDEATHSIG, signal) and throws std::system_error (via
  // SYSASSERT) if it fails. On other platforms it does nothing.
  module.def("_prctl_pr_set_pdeathsig", [](int signal) {
#if defined(__linux__)
    // prctl's trailing arguments are read as unsigned long (see prctl(2)),
    // so pass the signal number with that type rather than a bare int.
    auto rv = prctl(PR_SET_PDEATHSIG, static_cast<unsigned long>(signal));
    SYSASSERT(rv, "prctl");
#else
    (void)signal; // Unused on non-Linux platforms.
#endif
  });

  Py_RETURN_TRUE;
}
|
|
|
|
} // namespace
|
|
|
|
// multiprocessing methods on torch._C
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
|
|
static PyMethodDef methods[] = {
|
|
{
|
|
"_multiprocessing_init",
|
|
multiprocessing_init,
|
|
METH_NOARGS,
|
|
nullptr,
|
|
},
|
|
{nullptr, nullptr, 0, nullptr},
|
|
};
|
|
|
|
// Returns a pointer to the first entry of this file's sentinel-terminated
// `methods` table.
PyMethodDef* python_functions() {
  PyMethodDef* const table_begin = &methods[0];
  return table_begin;
}
|
|
|
|
} // namespace multiprocessing
|
|
} // namespace torch
|