Move torch.cuda's atfork handler into C++ (#29101)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/23401.

We cannot rely on `multiprocessing.util.register_after_fork`, since it is only called for processes created by the `multiprocessing` module and not for those created by `os.fork()`. A handler registered through `pthread_atfork`, by contrast, does always get called in the forked child. However, I don't think it is safe to call Python functions inside the `atfork` handler, so the Python code has to be a bit more careful when checking `_initialized`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/29101

Differential Revision: D18355451

Pulled By: ezyang

fbshipit-source-id: 4d4253a3669796212c099dad4e5bdfdb0df40469
committed by Facebook Github Bot
parent be757957ba
commit bb119d957e
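The summary hinges on two details: a handler registered with pthread_atfork fires in every forked child (including a plain fork()/os.fork(), which multiprocessing.util.register_after_fork never sees), and that handler must not touch the Python API, so it only flips a C-level flag that Python can query afterwards. Below is a minimal, self-contained sketch of that pattern; it mirrors the names used in the diff (in_bad_fork, forked_child, poison_fork) but is illustrative POSIX/C++, not the PyTorch sources.

#include <pthread.h>
#include <sys/wait.h>
#include <unistd.h>

#include <cstdio>
#include <mutex>

static bool in_bad_fork = false;  // set only in children forked after "init"

// Child-side atfork handler: runs in the forked child and never calls Python.
static void forked_child() {
  in_bad_fork = true;
}

// Idempotent registration; safe to call from every entry point that can run
// before full initialization (e.g. a device_count-style query).
static void poison_fork() {
  static std::once_flag flag;
  std::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
}

int main() {
  poison_fork();           // pretend some pre-init query ran first
  pid_t pid = fork();      // a raw fork(), no multiprocessing involved
  if (pid == 0) {
    // The child sees the flag set by the atfork handler and can refuse to
    // (re)initialize, which is what the Python-level check relies on.
    std::printf("child: in_bad_fork=%d\n", in_bad_fork);
    _exit(0);
  }
  waitpid(pid, nullptr, 0);
  std::printf("parent: in_bad_fork=%d\n", in_bad_fork);
  return 0;
}

Compiled as ordinary C++ (e.g. g++ -std=c++11 example.cpp -pthread), the parent prints in_bad_fork=0 and the child prints in_bad_fork=1, which is the signal the new _cuda_isInBadFork binding in the diff below exposes to Python.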
@@ -22,9 +22,33 @@
 #include <torch/csrc/Generator.h>
 #include <torch/csrc/python_headers.h>

+#ifndef WIN32
+#include <pthread.h>
+#endif
+
 using namespace torch;

-THCState *state;
+THCState *state = nullptr;
+static bool in_bad_fork = false; // True for children forked after cuda init
+
+#ifndef WIN32
+// Called in the forked child if cuda has already been initialized
+static void forked_child() {
+  in_bad_fork = true;
+  utils::set_run_yet_variable_to_false();
+  state = nullptr;
+}
+#endif
+
+// Should be called before the first cuda call.
+// Note: This is distinct from initExtension because a stub cuda implementation
+// has some working functions (e.g. device_count) but cannot fully initialize.
+static void poison_fork() {
+#ifndef WIN32
+  static std::once_flag flag;
+  std::call_once(flag, []{ pthread_atfork(nullptr, nullptr, forked_child); });
+#endif
+}

 ////////////////////////////////////////////////////////////////////////////////
 // CUDA management methods
@@ -61,16 +85,14 @@ PyObject * THCPModule_getDevice_wrap(PyObject *self, PyObject *noargs)
 PyObject * THCPModule_getDeviceCount_wrap(PyObject *self, PyObject *noargs)
 {
   HANDLE_TH_ERRORS
-  //torch::utils::cuda_lazy_init();
+  poison_fork();
   return PyLong_FromLong(at::cuda::device_count());
   END_HANDLE_TH_ERRORS
 }

-PyObject * THCPModule_set_run_yet_variable_to_false_wrap(PyObject *self, PyObject *noargs)
-{
+static PyObject * THCPModule_isInBadFork(PyObject *self, PyObject *noargs) {
   HANDLE_TH_ERRORS
-  torch::utils::set_run_yet_variable_to_false();
-  Py_RETURN_NONE;
+  return PyBool_FromLong(in_bad_fork);
   END_HANDLE_TH_ERRORS
 }

@@ -373,6 +395,8 @@ static void bindCudaDeviceProperties(PyObject* module) {
 static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs)
 {
   HANDLE_TH_ERRORS
+  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
+  poison_fork();
   state = at::globalContext().lazyInitCUDA();

   auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
@@ -446,8 +470,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
   {"_cuda_setDevice", (PyCFunction)THCPModule_setDevice_wrap, METH_O, nullptr},
   {"_cuda_getDevice", (PyCFunction)THCPModule_getDevice_wrap, METH_NOARGS, nullptr},
   {"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, nullptr},
-  {"_cuda_set_run_yet_variable_to_false",
-    (PyCFunction)THCPModule_set_run_yet_variable_to_false_wrap, METH_NOARGS, nullptr},
+  {"_cuda_isInBadFork", (PyCFunction)THCPModule_isInBadFork, METH_NOARGS, nullptr},
   {"_cuda_getCurrentStream",
     (PyCFunction)THCPModule_getCurrentStream_wrap, METH_O, nullptr},
   {"_cuda_getDefaultStream",