Move torch.cuda's atfork handler into C++ (#29101)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/23401

We cannot rely on `multiprocessing.util.register_after_fork` since it is only
called for processes created by the `multiprocessing` module and not `os.fork()`.

Moving to `pthread_atfork` does always get called. However, I don't think it's safe to call python functions inside of the `atfork` handler so the python code has to be a bit more careful when checking `_initialized`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29101

Differential Revision: D18355451

Pulled By: ezyang

fbshipit-source-id: 4d4253a3669796212c099dad4e5bdfdb0df40469
This commit is contained in:
Peter Bell
2019-11-11 07:32:28 -08:00
committed by Facebook Github Bot
parent be757957ba
commit bb119d957e
4 changed files with 52 additions and 43 deletions

View File

@ -22,9 +22,33 @@
#include <torch/csrc/Generator.h>
#include <torch/csrc/python_headers.h>
#ifndef WIN32
#include <pthread.h>
#endif
using namespace torch;
THCState *state;
THCState *state = nullptr;
static bool in_bad_fork = false; // True for children forked after cuda init
#ifndef WIN32
// Called in the forked child if cuda has already been initialized
static void forked_child() {
in_bad_fork = true;
utils::set_run_yet_variable_to_false();
state = nullptr;
}
#endif
// Should be called before the first cuda call.
// Note: This is distinct from initExtension because a stub cuda implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
static void poison_fork() {
#ifndef WIN32
static std::once_flag flag;
std::call_once(flag, []{ pthread_atfork(nullptr, nullptr, forked_child); });
#endif
}
////////////////////////////////////////////////////////////////////////////////
// CUDA management methods
@ -61,16 +85,14 @@ PyObject * THCPModule_getDevice_wrap(PyObject *self, PyObject *noargs)
PyObject * THCPModule_getDeviceCount_wrap(PyObject *self, PyObject *noargs)
{
HANDLE_TH_ERRORS
//torch::utils::cuda_lazy_init();
poison_fork();
return PyLong_FromLong(at::cuda::device_count());
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_set_run_yet_variable_to_false_wrap(PyObject *self, PyObject *noargs)
{
static PyObject * THCPModule_isInBadFork(PyObject *self, PyObject *noargs) {
HANDLE_TH_ERRORS
torch::utils::set_run_yet_variable_to_false();
Py_RETURN_NONE;
return PyBool_FromLong(in_bad_fork);
END_HANDLE_TH_ERRORS
}
@ -373,6 +395,8 @@ static void bindCudaDeviceProperties(PyObject* module) {
static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs)
{
HANDLE_TH_ERRORS
TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
poison_fork();
state = at::globalContext().lazyInitCUDA();
auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
@ -446,8 +470,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
{"_cuda_setDevice", (PyCFunction)THCPModule_setDevice_wrap, METH_O, nullptr},
{"_cuda_getDevice", (PyCFunction)THCPModule_getDevice_wrap, METH_NOARGS, nullptr},
{"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, nullptr},
{"_cuda_set_run_yet_variable_to_false",
(PyCFunction)THCPModule_set_run_yet_variable_to_false_wrap, METH_NOARGS, nullptr},
{"_cuda_isInBadFork", (PyCFunction)THCPModule_isInBadFork, METH_NOARGS, nullptr},
{"_cuda_getCurrentStream",
(PyCFunction)THCPModule_getCurrentStream_wrap, METH_O, nullptr},
{"_cuda_getDefaultStream",