Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[4/4] Intel GPU Runtime Upstreaming for Device (#116869)
# Motivation

As described in [[1/4] Intel GPU Runtime Upstreaming for Device](https://github.com/pytorch/pytorch/pull/116019) and in [[RFC] Intel GPU Runtime Upstreaming](https://github.com/pytorch/pytorch/issues/114842), this last PR in the series covers the changes related to lazy initialization.

# Design

This PR primarily adds multi-processing support via lazy initialization: the runtime is initialized lazily, so XPU is not touched until the first time it is accessed. In our design, we generalize `cuda_lazy_init` into `device_lazy_init`, a device-agnostic API that can support any backend, and change `maybe_initialize_cuda` to `maybe_initialize_device` to support lazy initialization for both CUDA and XPU while maintaining scalability.

# Additional Context

We adopt a design similar to CUDA's, so some code is shared with CUDA.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116869
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/gujinghui, https://github.com/malfet
ghstack dependencies: #119248
Committed by: PyTorch MergeBot
Parent: 3cb7ec312c
Commit: 9a992b0918
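After this change, importing `torch.xpu` no longer initializes the XPU runtime; the first real access does. A minimal sketch of the observable behavior (assuming an XPU-enabled build with at least one device):

```python
import torch

# Importing torch (and torch.xpu) no longer touches the XPU runtime.
print(torch.xpu.is_initialized())  # False: nothing is initialized yet

if torch.xpu.is_available():  # queries the driver, still no full init
    torch.xpu.set_device(0)  # first real access triggers lazy init
    print(torch.xpu.is_initialized())  # True
```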
@@ -14,5 +14,7 @@ torch.xpu
     get_device_capability
     get_device_name
     get_device_properties
+    init
     is_available
+    is_initialized
     set_device
@@ -46,6 +46,28 @@ class TestXpu(TestCase):
         self.assertTrue(device_capability["max_work_group_size"] > 0)
         self.assertTrue(device_capability["max_num_sub_groups"] > 0)
 
+    def test_wrong_xpu_fork(self):
+        stderr = TestCase.runWithPytorchAPIUsageStderr(
+            """\
+import torch
+from torch.multiprocessing import Process
+def run(rank):
+    torch.xpu.set_device(rank)
+if __name__ == "__main__":
+    size = 2
+    processes = []
+    for rank in range(size):
+        # it would work fine without the line below
+        torch.xpu.set_device(0)
+        p = Process(target=run, args=(rank,))
+        p.start()
+        processes.append(p)
+    for p in processes:
+        p.join()
+"""
+        )
+        self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.")
+
 
 if __name__ == "__main__":
     run_tests()
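For reference, the failure exercised by this test is avoided by using the `spawn` start method, as the new error message suggests. A minimal sketch (assuming an XPU-enabled build with at least two devices):

```python
import torch
import torch.multiprocessing as mp


def run(rank: int) -> None:
    # Each child initializes XPU itself; no runtime state is inherited.
    torch.xpu.set_device(rank)


if __name__ == "__main__":
    # "spawn" starts fresh interpreters, so the parent's XPU
    # initialization is never duplicated into a forked child.
    mp.set_start_method("spawn")
    size = 2
    processes = [mp.Process(target=run, args=(rank,)) for rank in range(size)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
```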
@@ -518,6 +518,7 @@ static PyObject* THPVariable_xpu(PyObject* self, PyObject* args, PyObject* kwargs)
   auto device = r.isNone(0) ? at::Device(at::DeviceType::XPU) : r.device(0);
   auto opt_memory_format = r.memoryformatOptional(2);
   TORCH_CHECK(device.is_xpu(), "Invalid device, must be xpu device");
+  torch::utils::device_lazy_init(at::kXPU);
   return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format));
   END_HANDLE_TH_ERRORS
 }
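With this hook in place, `Tensor.xpu()` becomes one of the lazy entry points. A hedged usage sketch (assuming an XPU device is present):

```python
import torch

x = torch.empty(4)  # CPU tensor; the XPU runtime is still untouched
# Tensor.xpu() now calls device_lazy_init(at::kXPU) before dispatching,
# so this conversion is what actually initializes the runtime.
y = x.xpu()
```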
@@ -28,7 +28,7 @@ void set_requires_device_init(at::DeviceType device_type, bool value);
 
 static inline void maybe_initialize_device(at::Device& device) {
   // Add more devices here to enable lazy initialization.
-  if (device.is_cuda()) {
+  if (device.is_cuda() || device.is_xpu()) {
     device_lazy_init(device.type());
   }
 }
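Conceptually, `device_lazy_init` tracks a per-device-type "needs init" flag that is set once and can be re-armed after a fork. A minimal Python sketch of that pattern (illustrative only; `_requires_init` and `_init_backend` are hypothetical names, the real implementation lives in C++):

```python
import threading

_requires_init = {"cuda": True, "xpu": True}  # hypothetical per-backend flags
_lock = threading.Lock()


def _init_backend(device_type: str) -> None:
    # Stand-in for the real backend initialization (hypothetical helper).
    print(f"initializing {device_type} runtime")


def device_lazy_init(device_type: str) -> None:
    # Initialize a backend at most once; later calls are no-ops
    # until the flag is re-armed (e.g. in a forked child).
    with _lock:
        if _requires_init.get(device_type, False):
            _init_backend(device_type)
            _requires_init[device_type] = False


def set_requires_device_init(device_type: str, value: bool) -> None:
    # Re-arm (or disarm) lazy initialization for a device type.
    with _lock:
        _requires_init[device_type] = value
```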
@@ -4,13 +4,38 @@
 #include <c10/xpu/XPUFunctions.h>
 #include <torch/csrc/Module.h>
 #include <torch/csrc/THP.h>
 #include <torch/csrc/utils/device_lazy_init.h>
 #include <torch/csrc/utils/python_numbers.h>
 #include <torch/csrc/utils/python_strings.h>
 
 #include <pthread.h>
 
 using namespace torch;
 
 static bool in_bad_fork = false; // True for children forked after xpu init
 
 // Called in the forked child if xpu has already been initialized
 static void forked_child() {
   in_bad_fork = true;
   torch::utils::set_requires_device_init(at::kXPU, true);
 }
 
 // Should be called before the first xpu call. It is mainly called in lazy_init.
 // Note: This is distinct from initExtension because a stub xpu implementation
 // has some working functions (e.g. device_count) but cannot fully initialize.
 static void poison_fork() {
   static c10::once_flag flag;
   c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
 }
 
 // XPU management methods
 
 static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
   return PyBool_FromLong(in_bad_fork);
   END_HANDLE_TH_ERRORS
 }
 
 PyObject* THXPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
   HANDLE_TH_ERRORS
   TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to set_device");
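The fork-poisoning mechanism above has a direct Python analog. A hedged sketch of the same idea using `os.register_at_fork` (illustrative only; the actual guard is installed in C++ via `pthread_atfork`):

```python
import os

_in_bad_fork = False  # mirrors the C++ in_bad_fork flag


def _forked_child() -> None:
    # Runs in the child right after fork(); marks the runtime as poisoned
    # so the next XPU call can raise a clear error instead of misbehaving.
    global _in_bad_fork
    _in_bad_fork = True


# Register once, before the runtime is first touched (POSIX only).
os.register_at_fork(after_in_child=_forked_child)
```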
@@ -30,6 +55,8 @@ PyObject* THXPModule_exchangeDevice_wrap(PyObject* self, PyObject* arg) {
   if (device < 0) {
     return THPUtils_packInt32(-1);
   }
+
+  torch::utils::device_lazy_init(at::kXPU);
   int current_device = c10::xpu::exchange_device(device);
 
   return THPUtils_packInt32(current_device);
@@ -45,6 +72,8 @@ PyObject* THXPModule_maybeExchangeDevice_wrap(PyObject* self, PyObject* arg) {
   if (device < 0) {
     return THPUtils_packInt32(-1);
   }
+
+  torch::utils::device_lazy_init(at::kXPU);
   int current_device = c10::xpu::maybe_exchange_device(device);
 
   return THPUtils_packInt32(current_device);
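`exchange_device` and `maybe_exchange_device` exist to make RAII-style device guards cheap: the first swap remembers the previous device, and the restore can be skipped when nothing changed. A self-contained Python sketch of that usage pattern (illustrative only; `_exchange_device` and `_maybe_exchange_device` are hypothetical stand-ins for the private bindings):

```python
# Module-level stand-in for the runtime's notion of "current device".
_current_device = 0


def _exchange_device(index: int) -> int:
    # Hypothetical stand-in: set the current device, return the previous one.
    global _current_device
    prev = _current_device
    if index >= 0:
        _current_device = index
    return prev


def _maybe_exchange_device(index: int) -> int:
    # Hypothetical stand-in: only swap when it would change the device.
    if index >= 0 and index != _current_device:
        return _exchange_device(index)
    return _current_device


class XpuDeviceGuard:
    """Temporarily switch the current device, restoring it on exit."""

    def __init__(self, index: int) -> None:
        self.idx = index
        self.prev_idx = -1

    def __enter__(self) -> None:
        # Swap in the target device, remembering the previous one.
        self.prev_idx = _exchange_device(self.idx)

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        # Restore the previous device; skipped if nothing would change.
        _maybe_exchange_device(self.prev_idx)
        return False
```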
@@ -63,7 +92,7 @@ PyObject* THXPModule_getDevice_wrap(PyObject* self, PyObject* noargs) {
 
 PyObject* THXPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
-
+  poison_fork();
   return THPUtils_packUInt64(at::xpu::device_count());
   END_HANDLE_TH_ERRORS
 }
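Note that `device_count` installs the fork guard but does not fully initialize the runtime, which is what lets `is_available()` stay side-effect free. A hedged sketch of the visible behavior (assuming an XPU-enabled build):

```python
import torch

# device_count() only installs the fork guard (poison_fork) and
# queries the driver; it does not flip the initialized state.
n = torch.xpu.device_count()
print(n, torch.xpu.is_initialized())  # e.g. "2 False"
```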
@@ -145,6 +174,8 @@ static void bindGetDeviceProperties(PyObject* module) {
 // classes
 static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
+  poison_fork();
 
   auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu"));
   if (!m)
@@ -172,6 +203,7 @@ static struct PyMethodDef _THXPModule_methods[] = {
      THXPModule_getDeviceCount_wrap,
      METH_NOARGS,
      nullptr},
+    {"_xpu_isInBadFork", THXPModule_isInBadFork_wrap, METH_NOARGS, nullptr},
     {nullptr}};
 
 PyMethodDef* THXPModule_methods() {
@@ -2,8 +2,10 @@ r"""
 This package introduces support for the XPU backend, specifically tailored for
 Intel GPU optimization.
 
-You can use :func:`is_available()` to determine if your system supports XPU.
+This package is lazily initialized, so you can always import it, and use
+:func:`is_available()` to determine if your system supports XPU.
 """
+import threading
 from functools import lru_cache
 from typing import Any, Dict, Optional, Union
 
@@ -12,6 +14,9 @@ import torch._C
 from .. import device as _device
 from ._utils import _dummy_type, _get_device_index
 
+_initialized = False
+_initialization_lock = threading.Lock()
+_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
 _device_t = Union[_device, str, int, None]
 
 
@@ -35,11 +40,6 @@ else:
     raise NotImplementedError("PyTorch was compiled without XPU support")
 
 
-# TODO: Enable lazy init.
-if _is_compiled():
-    torch._C._xpu_init()
-
-
 @lru_cache(maxsize=1)
 def device_count() -> int:
     r"""Return the number of XPU devices available."""
@@ -59,6 +59,42 @@ def is_bf16_supported():
     return True
 
 
+def is_initialized():
+    r"""Return whether PyTorch's XPU state has been initialized."""
+    return _initialized and not _is_in_bad_fork()
+
+
+def init():
+    r"""Initialize PyTorch's XPU state.
+    This is a Python API about lazy initialization that avoids initializing
+    XPU until the first time it is accessed. Does nothing if the XPU state is
+    already initialized.
+    """
+    _lazy_init()
+
+
+def _lazy_init():
+    global _initialized
+    if is_initialized():
+        return
+    with _initialization_lock:
+        # This test was protected via GIL. Double-check whether XPU has
+        # already been initialized.
+        if is_initialized():
+            return
+        # Stop promptly upon encountering a bad fork error.
+        if _is_in_bad_fork():
+            raise RuntimeError(
+                "Cannot re-initialize XPU in forked subprocess. To use XPU with "
+                "multiprocessing, you must use the 'spawn' start method"
+            )
+        if not _is_compiled():
+            raise AssertionError("Torch not compiled with XPU enabled")
+        # This function inits XPU backend and detects bad fork processing.
+        torch._C._xpu_init()
+        _initialized = True
+
+
 class _DeviceGuard:
     def __init__(self, index: int):
         self.idx = index
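The double-checked locking in `_lazy_init` makes concurrent first use safe. A small illustration (assuming an XPU-enabled build; `torch._C._xpu_init` runs exactly once no matter how many threads race in):

```python
import threading

import torch

# Many threads may call init() at once; the lock plus the second
# is_initialized() check ensure the backend is initialized only once.
threads = [threading.Thread(target=torch.xpu.init) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert torch.xpu.is_initialized()
```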
@@ -114,6 +150,7 @@ def set_device(device: _device_t) -> None:
         device (torch.device or int or str): selected device. This function is a
             no-op if this argument is negative.
     """
+    _lazy_init()
    device = _get_device_index(device)
    if device >= 0:
        torch._C._xpu_setDevice(device)
@@ -165,6 +202,7 @@ def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties:
     Returns:
         _XpuDeviceProperties: the properties of the device
     """
+    _lazy_init()
    device = _get_device_index(device, optional=True)
    if device < 0 or device >= device_count():
        raise AssertionError("Invalid device index")
@@ -173,6 +211,7 @@ def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties:
 
 def current_device() -> int:
     r"""Return the index of a currently selected device."""
+    _lazy_init()
    return torch._C._xpu_getDevice()
 
 
@@ -197,7 +236,9 @@ __all__ = [
     "get_device_capability",
     "get_device_name",
     "get_device_properties",
+    "init",
     "is_available",
     "is_bf16_supported",
+    "is_initialized",
     "set_device",
 ]