Generalize poison fork logic for each device backend (#144664)

# Motivation
Generalize the poison_fork code to make it reusable across different devices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144664
Approved by: https://github.com/EikanWang, https://github.com/albanD
This commit is contained in:
Yu, Guangye
2025-04-11 10:49:10 +00:00
committed by PyTorch MergeBot
parent 304633152c
commit b0810168a3
6 changed files with 73 additions and 108 deletions

View File

@ -51,32 +51,9 @@
#include <sstream>
#include <thread>
#include <unordered_map>
#ifndef WIN32
#include <pthread.h>
#endif
using namespace torch;
static bool in_bad_fork = false; // True for children forked after cuda init
#ifndef WIN32
// Called in the forked child if cuda has already been initialized
static void forked_child() {
in_bad_fork = true;
torch::utils::set_requires_device_init(at::kCUDA, true);
}
#endif
// Should be called before the first cuda call.
// Note: This is distinct from initExtension because a stub cuda implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
static void poison_fork() {
#ifndef WIN32
static auto result [[maybe_unused]] =
pthread_atfork(nullptr, nullptr, forked_child);
#endif
}
////////////////////////////////////////////////////////////////////////////////
// CUDA management methods
////////////////////////////////////////////////////////////////////////////////
@ -160,14 +137,17 @@ PyObject* THCPModule_canDeviceAccessPeer_wrap(PyObject* self, PyObject* args) {
PyObject* THCPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
poison_fork();
// Note: This is distinct from initExtension because a stub cuda
// implementation has some working functions (e.g. device_count) but cannot
// fully initialize.
torch::utils::register_fork_handler_for_device_init(at::kCUDA);
return THPUtils_packUInt64(at::cuda::device_count());
END_HANDLE_TH_ERRORS
}
PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
poison_fork();
torch::utils::register_fork_handler_for_device_init(at::kCUDA);
#ifdef CUDA_ARCH_FLAGS
static const char* flags = C10_STRINGIZE(CUDA_ARCH_FLAGS);
return THPUtils_packString(flags);
@ -179,7 +159,7 @@ PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) {
static PyObject* THCPModule_isInBadFork(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
return PyBool_FromLong(in_bad_fork);
return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kCUDA));
END_HANDLE_TH_ERRORS
}
@ -1513,8 +1493,8 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) {
"please rebuild pytorch without asan if you need to use this module");
#endif
HANDLE_TH_ERRORS
TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
poison_fork();
TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kCUDA));
torch::utils::register_fork_handler_for_device_init(at::kCUDA);
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));

View File

@ -6,16 +6,12 @@
#include <torch/csrc/THP.h>
#include <torch/csrc/mps/Module.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <memory>
// pthread.h is included for tracking bad forks
#ifndef WIN32
#include <pthread.h>
#endif
#ifdef USE_MPS
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/MetalShaderLibrary.h>
@ -23,27 +19,9 @@
namespace torch::mps {
namespace {
// True for children forked after mps init
static bool in_bad_fork = false;
// Called in the forked child if mps has already been initialized
static void forked_mps_child() {
in_bad_fork = true;
}
// Should be called before the first mps call.
static void track_bad_mps_fork() {
#ifndef WIN32
static auto result [[maybe_unused]] =
pthread_atfork(nullptr, nullptr, forked_mps_child);
#endif
}
} // namespace
static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
return PyBool_FromLong(in_bad_fork);
return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kMPS));
END_HANDLE_TH_ERRORS
}
@ -51,7 +29,7 @@ static PyObject* MPSModule_getDefaultMPSGenerator(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
track_bad_mps_fork();
torch::utils::register_fork_handler_for_device_init(at::kMPS);
return THPGenerator_initDefaultGenerator(
at::detail::getMPSHooks().getDefaultGenerator());
END_HANDLE_TH_ERRORS
@ -59,8 +37,8 @@ static PyObject* MPSModule_getDefaultMPSGenerator(
static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
track_bad_mps_fork();
if (at::detail::getMPSHooks().hasMPS()) {
torch::utils::register_fork_handler_for_device_init(at::kMPS);
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;

View File

@ -7,38 +7,15 @@
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pybind.h>
#ifndef WIN32
#include <pthread.h>
#endif
namespace torch::mtia {
static bool in_bad_fork = false; // True for children forked after mtia init
#ifndef WIN32
// Called in the forked child if mtia has already been initialized
static void forked_child() {
in_bad_fork = true;
torch::utils::set_requires_device_init(at::kMTIA, true);
}
#endif
// Should be called before the first mtia call.
// Note: This is distinct from initExtension because a stub mtia implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
static void poison_fork() {
#ifndef WIN32
static auto result [[maybe_unused]] =
pthread_atfork(nullptr, nullptr, forked_child);
#endif
}
void initModule(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
m.def("_mtia_init", []() {
TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
poison_fork();
TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kMTIA));
torch::utils::register_fork_handler_for_device_init(at::kMTIA);
at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
});
@ -47,7 +24,9 @@ void initModule(PyObject* module) {
return at::detail::isMTIAHooksBuilt();
});
m.def("_mtia_isInBadFork", []() { return in_bad_fork; });
m.def("_mtia_isInBadFork", []() {
return torch::utils::is_device_in_bad_fork(at::kMTIA);
});
m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) {
torch::utils::device_lazy_init(at::kMTIA);

View File

@ -1,13 +1,22 @@
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/CallOnce.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#ifndef WIN32
#include <pthread.h>
#endif
namespace torch::utils {
namespace {
std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_initialized{};
std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_in_bad_fork{};
std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES>
at_fork_once_flags{};
} // anonymous namespace
@ -58,4 +67,28 @@ void set_requires_device_init(at::DeviceType device_type, bool value) {
is_initialized[static_cast<int>(device_type)] = !value;
}
bool is_device_in_bad_fork(at::DeviceType device_type) {
return is_in_bad_fork[static_cast<int>(device_type)];
}
void set_device_in_bad_fork(at::DeviceType device_type, bool value) {
is_in_bad_fork[static_cast<int>(device_type)] = value;
}
// Should be called before the first device runtime call.
void register_fork_handler_for_device_init(at::DeviceType device_type) {
#ifndef WIN32
auto& flag = at_fork_once_flags[static_cast<int>(device_type)];
c10::call_once(flag, [device_type]() {
static at::DeviceType at_fork_device_type = device_type;
pthread_atfork(nullptr, nullptr, []() {
set_device_in_bad_fork(at_fork_device_type, true);
if (is_device_lazy_init_supported(at_fork_device_type)) {
set_requires_device_init(at_fork_device_type, true);
}
});
});
#endif
}
} // namespace torch::utils

View File

@ -67,4 +67,21 @@ inline void maybe_initialize_device(
bool is_device_initialized(at::DeviceType device_type);
TORCH_PYTHON_API bool is_device_in_bad_fork(at::DeviceType device_type);
TORCH_PYTHON_API void set_device_in_bad_fork(
at::DeviceType device_type,
bool value);
TORCH_PYTHON_API void register_fork_handler_for_device_init(
at::DeviceType device_type);
inline void maybe_register_fork_handler_for_device_init(
std::optional<at::DeviceType>& device_type) {
if (!device_type.has_value()) {
return;
}
register_fork_handler_for_device_init(device_type.value());
}
} // namespace torch::utils

View File

@ -11,32 +11,8 @@
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/xpu/Module.h>
#ifndef WIN32
#include <pthread.h>
#endif
using namespace torch;
static bool in_bad_fork = false; // True for children forked after xpu init
#ifndef WIN32
// Called in the forked child if xpu has already been initialized
static void forked_child() {
in_bad_fork = true;
torch::utils::set_requires_device_init(at::kXPU, true);
}
#endif
// Should be called before the first xpu call. It is mainly called in lazy_init.
// Note: This is distinct from initExtension because a stub xpu implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
static void poison_fork() {
#ifndef WIN32
static auto result [[maybe_unused]] =
pthread_atfork(nullptr, nullptr, forked_child);
#endif
}
// XPU management methods
static PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) {
@ -52,7 +28,7 @@ static PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) {
static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
return PyBool_FromLong(in_bad_fork);
return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kXPU));
END_HANDLE_TH_ERRORS
}
@ -115,7 +91,9 @@ static PyObject* THXPModule_getDeviceCount_wrap(
PyObject* self,
PyObject* noargs) {
HANDLE_TH_ERRORS
poison_fork();
// Note: This is distinct from initExtension because a stub xpu implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
torch::utils::register_fork_handler_for_device_init(at::kXPU);
return THPUtils_packUInt64(at::xpu::device_count());
END_HANDLE_TH_ERRORS
}
@ -420,8 +398,8 @@ static void initXpuMethodBindings(PyObject* module) {
// classes
static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
poison_fork();
TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kXPU));
torch::utils::register_fork_handler_for_device_init(at::kXPU);
at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu"));