pytorch/torch/csrc/utils/device_lazy_init.cpp
Yu, Guangye 5c46600f84 [RELAND] refactor lazy init to device-agnostic (#119248)
# Motivation
This PR extends `cuda_lazy_init` into `device_lazy_init`, a device-agnostic API that can support any backend. It also renames `maybe_initialize_cuda` to `maybe_initialize_device`, which still performs lazy initialization for CUDA while making it straightforward to add other backends later.
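A minimal sketch of the `maybe_initialize_device` idea, assuming it simply forwards to `device_lazy_init` for the backends that opt in to lazy initialization (only CUDA here). `DeviceType`, `Device`, and the `init_calls` log are hypothetical stand-ins, not the real ATen types.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for at::DeviceType / at::Device (not the real ATen types).
enum class DeviceType { CPU, CUDA, XPU };

struct Device {
  DeviceType type;
  bool is_cuda() const { return type == DeviceType::CUDA; }
};

// Records which backends were lazily initialized; stands in for the real work.
std::vector<DeviceType> init_calls;

void device_lazy_init(DeviceType type) { init_calls.push_back(type); }

// Only backends checked here opt in to lazy initialization; in this sketch, CUDA only.
void maybe_initialize_device(const Device& device) {
  if (device.is_cuda()) {
    device_lazy_init(device.type);
  }
}
```

Supporting another backend then means extending the condition in one place rather than adding a new backend-specific helper.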

# Design
We maintain a flag for each backend to manage the lazy initialization state separately.
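The per-backend flags can be pictured as one boolean per device type, indexed by the enum value. This is a simplified sketch: `DeviceType` and `kNumDeviceTypes` are hypothetical, with `kNumDeviceTypes` playing the role of `at::COMPILE_TIME_MAX_DEVICE_TYPES` in the real code.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical device enum; COUNT is used only to size the flag table.
enum class DeviceType : std::size_t { CPU, CUDA, XPU, COUNT };

constexpr std::size_t kNumDeviceTypes = static_cast<std::size_t>(DeviceType::COUNT);

// One flag per backend, value-initialized to false ("not yet initialized").
std::array<bool, kNumDeviceTypes> is_initialized{};

bool needs_init(DeviceType t) {
  return !is_initialized[static_cast<std::size_t>(t)];
}

void mark_initialized(DeviceType t) {
  is_initialized[static_cast<std::size_t>(t)] = true;
}
```

Marking CUDA as initialized leaves the XPU flag untouched, which is what keeps each backend's state independent.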

# Additional Context
No additional unit tests are needed.
This is a reland; the original PR is [refactor lazy init to device-agnostic](https://github.com/pytorch/pytorch/pull/118846).
This is a common (not XPU-specific) PR and does not trigger the XPU ciflow.

Differential Revision: [D53478332](https://our.internmc.facebook.com/intern/diff/D53478332)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119248
Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/jgong5, https://github.com/atalman
2024-02-07 15:58:51 +00:00


#include <torch/csrc/utils/device_lazy_init.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <array>
#include <iostream>
#include <string>

namespace torch::utils {
namespace {

// One lazy-initialization flag per device type.
std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_initialized{};

} // anonymous namespace

void device_lazy_init(at::DeviceType device_type) {
  pybind11::gil_scoped_acquire g;
  // Protected by the GIL. We don't use call_once because under ASAN it
  // has a buggy implementation that deadlocks if an instance throws an
  // exception. In any case, call_once isn't necessary, because we
  // have taken a lock.
  if (is_initialized[static_cast<int>(device_type)]) {
    return;
  }

  std::string module_name = "torch." + at::DeviceTypeName(device_type, true);
  auto module = THPObjectPtr(PyImport_ImportModule(module_name.c_str()));
  if (!module) {
    throw python_error();
  }

  auto res = THPObjectPtr(PyObject_CallMethod(module.get(), "_lazy_init", ""));
  if (!res) {
    throw python_error();
  }

  is_initialized[static_cast<int>(device_type)] = true;
}
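
// `value == true` marks the backend as needing (re)initialization, so the next
// device_lazy_init(device_type) call imports the module and calls _lazy_init() again.
void set_requires_device_init(at::DeviceType device_type, bool value) {
  is_initialized[static_cast<int>(device_type)] = !value;
}

} // namespace torch::utils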
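The control flow of `device_lazy_init` can be reproduced without Python as a sketch. Assumptions: a `std::mutex` stands in for the GIL, a caller-supplied `init` callback stands in for importing `torch.<device>` and calling `_lazy_init`, and `DeviceType` is a hypothetical enum. As in the original, the flag is set only after initialization succeeds, so an exception leaves the backend uninitialized and a later call retries.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <functional>
#include <mutex>
#include <stdexcept>

// Hypothetical device enum; COUNT sizes the flag table.
enum class DeviceType : std::size_t { CPU, CUDA, XPU, COUNT };

namespace {
// std::mutex stands in for the GIL that guards the flags in the real code.
std::mutex init_mutex;
std::array<bool, static_cast<std::size_t>(DeviceType::COUNT)> is_initialized{};
}  // namespace

// Runs `init` at most once per backend. If `init` throws, the flag stays false
// and lock_guard releases the mutex during unwinding, so a later call retries.
void device_lazy_init(DeviceType type, const std::function<void()>& init) {
  std::lock_guard<std::mutex> lock(init_mutex);
  bool& flag = is_initialized[static_cast<std::size_t>(type)];
  if (flag) {
    return;
  }
  init();  // may throw; the flag is left untouched in that case
  flag = true;
}

// value == true means "this backend must be initialized again".
void set_requires_device_init(DeviceType type, bool value) {
  std::lock_guard<std::mutex> lock(init_mutex);
  is_initialized[static_cast<std::size_t>(type)] = !value;
}
```

One difference from the real code: the GIL can be re-acquired by a thread that already holds it, whereas `std::mutex` is not reentrant, so in this sketch `init` must not call `device_lazy_init` itself.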