[4/4] Intel GPU Runtime Upstreaming for Device (#116869)

# Motivation
As the final PR in the series introduced in [[1/4] Intel GPU Runtime Upstreaming for Device](https://github.com/pytorch/pytorch/pull/116019) and proposed in [[RFC] Intel GPU Runtime Upstreaming](https://github.com/pytorch/pytorch/issues/114842), this PR covers the changes related to lazy initialization.

# Design
This PR primarily adds support for multiprocessing via lazy initialization: the runtime is initialized lazily, so XPU is not initialized until it is accessed for the first time. In our design, we generalize `cuda_lazy_init` into `device_lazy_init`, a device-agnostic API that can support any backend, and rename `maybe_initialize_cuda` to `maybe_initialize_device` so that lazy initialization works for both CUDA and XPU while remaining extensible to other backends.
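
For illustration, the intended user-visible behavior is roughly the following (a minimal sketch, assuming an XPU build with at least one device; not code from this PR):

```python
import torch

# Importing torch (and torch.xpu) no longer initializes the XPU runtime.
assert not torch.xpu.is_initialized()

if torch.xpu.is_available():
    # The first device-touching call (e.g. set_device, or moving a tensor
    # to an XPU device) triggers lazy initialization.
    torch.xpu.set_device(0)
    assert torch.xpu.is_initialized()
```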

# Additional Context
We adopt a design similar to CUDA's, so part of the code is shared with the CUDA implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116869
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/gujinghui, https://github.com/malfet
ghstack dependencies: #119248
Author: Yu, Guangye
Date: 2024-02-06 09:41:29 +00:00
Committed by: PyTorch MergeBot
Parent: 3cb7ec312c
Commit: 9a992b0918
6 changed files with 106 additions and 8 deletions

View File

@ -14,5 +14,7 @@ torch.xpu
get_device_capability
get_device_name
get_device_properties
init
is_available
is_initialized
set_device

View File

@ -46,6 +46,28 @@ class TestXpu(TestCase):
        self.assertTrue(device_capability["max_work_group_size"] > 0)
        self.assertTrue(device_capability["max_num_sub_groups"] > 0)

    def test_wrong_xpu_fork(self):
        stderr = TestCase.runWithPytorchAPIUsageStderr(
            """\
import torch
from torch.multiprocessing import Process
def run(rank):
    torch.xpu.set_device(rank)

if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        # it would work fine without the line below
        torch.xpu.set_device(0)
        p = Process(target=run, args=(rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
"""
        )
        self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.")


if __name__ == "__main__":
    run_tests()
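
The error message asserted above points at the supported pattern. A minimal sketch (not part of this diff) of per-rank XPU work using the 'spawn' start method could look like this:

```python
import torch
import torch.multiprocessing as mp


def run(rank):
    # Each spawned child initializes XPU on first use,
    # independently of the parent process.
    torch.xpu.set_device(rank)


if __name__ == "__main__":
    # 'spawn' starts fresh interpreters instead of forking a parent
    # that may already have initialized XPU.
    mp.spawn(run, nprocs=2)  # assumes two visible XPU devices
```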

View File

@ -518,6 +518,7 @@ static PyObject * THPVariable_xpu(PyObject* self, PyObject* args, PyObject* kwar
  auto device = r.isNone(0) ? at::Device(at::DeviceType::XPU) : r.device(0);
  auto opt_memory_format = r.memoryformatOptional(2);
  TORCH_CHECK(device.is_xpu(), "Invalid device, must be xpu device");
  torch::utils::device_lazy_init(at::kXPU);
  return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format));
  END_HANDLE_TH_ERRORS
}

View File

@ -28,7 +28,7 @@ void set_requires_device_init(at::DeviceType device_type, bool value);
static inline void maybe_initialize_device(at::Device& device) {
  // Add more devices here to enable lazy initialization.
-  if (device.is_cuda()) {
+  if (device.is_cuda() || device.is_xpu()) {
    device_lazy_init(device.type());
  }
}
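
In user terms, this hook is what lets ordinary tensor conversions initialize the backend on demand. A minimal sketch of the behavior it enables (assuming an XPU build and an available device):

```python
import torch

x = torch.randn(4)  # CPU tensor; XPU is still uninitialized
y = x.to("xpu")     # the XPU target device triggers lazy initialization
print(torch.xpu.is_initialized())  # True after the first XPU access
```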

View File

@ -4,13 +4,38 @@
#include <c10/xpu/XPUFunctions.h>
#include <torch/csrc/Module.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>

#include <pthread.h>

using namespace torch;

static bool in_bad_fork = false; // True for children forked after xpu init

// Called in the forked child if xpu has already been initialized
static void forked_child() {
  in_bad_fork = true;
  torch::utils::set_requires_device_init(at::kXPU, true);
}

// Should be called before the first xpu call. It is mainly called in lazy_init.
// Note: This is distinct from initExtension because a stub xpu implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
static void poison_fork() {
  static c10::once_flag flag;
  c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
}

// XPU management methods

static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return PyBool_FromLong(in_bad_fork);
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to set_device");
@ -30,6 +55,8 @@ PyObject* THXPModule_exchangeDevice_wrap(PyObject* self, PyObject* arg) {
  if (device < 0) {
    return THPUtils_packInt32(-1);
  }
  torch::utils::device_lazy_init(at::kXPU);
  int current_device = c10::xpu::exchange_device(device);
  return THPUtils_packInt32(current_device);
@ -45,6 +72,8 @@ PyObject* THXPModule_maybeExchangeDevice_wrap(PyObject* self, PyObject* arg) {
  if (device < 0) {
    return THPUtils_packInt32(-1);
  }
  torch::utils::device_lazy_init(at::kXPU);
  int current_device = c10::xpu::maybe_exchange_device(device);
  return THPUtils_packInt32(current_device);
@ -63,7 +92,7 @@ PyObject* THXPModule_getDevice_wrap(PyObject* self, PyObject* noargs) {
PyObject* THXPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  poison_fork();
  return THPUtils_packUInt64(at::xpu::device_count());
  END_HANDLE_TH_ERRORS
}
@ -145,6 +174,8 @@ static void bindGetDeviceProperties(PyObject* module) {
// classes
static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
  poison_fork();
  auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu"));
  if (!m)
@ -172,6 +203,7 @@ static struct PyMethodDef _THXPModule_methods[] = {
     THXPModule_getDeviceCount_wrap,
     METH_NOARGS,
     nullptr},
    {"_xpu_isInBadFork", THXPModule_isInBadFork_wrap, METH_NOARGS, nullptr},
    {nullptr}};

PyMethodDef* THXPModule_methods() {

View File

@ -2,8 +2,10 @@ r"""
This package introduces support for the XPU backend, specifically tailored for
Intel GPU optimization.

-You can use :func:`is_available()` to determine if your system supports XPU.
+This package is lazily initialized, so you can always import it, and use
+:func:`is_available()` to determine if your system supports XPU.
"""
import threading
from functools import lru_cache
from typing import Any, Dict, Optional, Union
@ -12,6 +14,9 @@ import torch._C
from .. import device as _device
from ._utils import _dummy_type, _get_device_index
_initialized = False
_initialization_lock = threading.Lock()
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]
@ -35,11 +40,6 @@ else:
    raise NotImplementedError("PyTorch was compiled without XPU support")

-# TODO: Enable lazy init.
-if _is_compiled():
-    torch._C._xpu_init()


@lru_cache(maxsize=1)
def device_count() -> int:
    r"""Return the number of XPU device available."""
@ -59,6 +59,42 @@ def is_bf16_supported():
    return True


def is_initialized():
    r"""Return whether PyTorch's XPU state has been initialized."""
    return _initialized and not _is_in_bad_fork()


def init():
    r"""Initialize PyTorch's XPU state.

    This is a Python API for lazy initialization that avoids initializing
    XPU until the first time it is accessed. Does nothing if the XPU state is
    already initialized.
    """
    _lazy_init()


def _lazy_init():
    global _initialized
    if is_initialized():
        return
    with _initialization_lock:
        # This test was protected via the GIL. Double-check whether XPU has
        # already been initialized.
        if is_initialized():
            return
        # Stop promptly upon encountering a bad fork error.
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize XPU in forked subprocess. To use XPU with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        if not _is_compiled():
            raise AssertionError("Torch not compiled with XPU enabled")
        # This function inits XPU backend and detects bad fork processing.
        torch._C._xpu_init()
        _initialized = True


class _DeviceGuard:
    def __init__(self, index: int):
        self.idx = index
@ -114,6 +150,7 @@ def set_device(device: _device_t) -> None:
        device (torch.device or int or str): selected device. This function is a
            no-op if this argument is negative.
    """
    _lazy_init()
    device = _get_device_index(device)
    if device >= 0:
        torch._C._xpu_setDevice(device)
@ -165,6 +202,7 @@ def get_device_properties(device: Optional[_device_t] = None) -> _XpuDevicePrope
    Returns:
        _XpuDeviceProperties: the properties of the device
    """
    _lazy_init()
    device = _get_device_index(device, optional=True)
    if device < 0 or device >= device_count():
        raise AssertionError("Invalid device index")
@ -173,6 +211,7 @@ def get_device_properties(device: Optional[_device_t] = None) -> _XpuDevicePrope


def current_device() -> int:
    r"""Return the index of a currently selected device."""
    _lazy_init()
    return torch._C._xpu_getDevice()
@ -197,7 +236,9 @@ __all__ = [
"get_device_capability",
"get_device_name",
"get_device_properties",
"init",
"is_available",
"is_bf16_supported",
"is_initialized",
"set_device",
]
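
As a quick usage note (a sketch of the added Python surface, assuming XPU support is compiled in): `torch.xpu.init()` and `torch.xpu.is_initialized()` let callers force or query initialization explicitly, while everyday code can simply rely on the lazy path:

```python
import torch

if torch.xpu.is_available():
    # Optional: force initialization up front (e.g. to surface driver
    # problems early); otherwise the first XPU call initializes on demand.
    torch.xpu.init()
    print(torch.xpu.is_initialized())    # True
    print(torch.xpu.get_device_name(0))  # name of device 0
```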