[2/2] Intel GPU Runtime Upstreaming for Stream (#117619)
# Motivation

According to [[1/2] Intel GPU Runtime Upstreaming for Stream](https://github.com/pytorch/pytorch/pull/117611), as mentioned in [[RFC] Intel GPU Runtime Upstreaming](https://github.com/pytorch/pytorch/issues/114842), this second PR covers the changes under the `python frontend`.

# Design

Currently, it primarily offers stream-related APIs, including
- `torch.xpu.StreamContext`
- `torch.xpu.current_stream`
- `torch.xpu.set_stream`
- `torch.xpu.synchronize`
- `torch._C._xpu_getCurrentRawStream`

# Additional Context

We will implement functions like `torch.xpu.Stream.wait_event`, `torch.xpu.Stream.wait_stream`, and `torch.xpu.Stream.record_event` in the next PR, which relates to `Event`.

The differences from CUDA: XPU has no default or external stream, and the following APIs are missing:
- `torch.cuda.ExternalStream`
- `torch.cuda.default_stream`
- `torch.cuda.is_current_stream_capturing`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117619
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/gujinghui, https://github.com/albanD
ghstack dependencies: #117611
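For illustration, here is a minimal sketch of how the new frontend APIs compose (not part of this diff; it assumes an XPU-enabled build with at least one device, mirroring the usage exercised in `test/test_xpu.py`):

```python
import torch

# Create a stream on the current XPU device and make it current.
s = torch.xpu.Stream()
torch.xpu.set_stream(s)
assert torch.xpu.current_stream() == s

# Temporarily switch streams; StreamContext restores the previous
# stream (and device) on exit.
other = torch.xpu.Stream()
with torch.xpu.stream(other):
    assert torch.xpu.current_stream() == other
assert torch.xpu.current_stream() == s

# Block the host until all reserved SYCL queues on device 0 finish.
torch.xpu.synchronize(0)

# Raw handle (address of the underlying sycl::queue) for interop.
raw = torch._C._xpu_getCurrentRawStream(0)
```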
This commit is contained in commit 8fd11cb307 (parent f2778e3874), committed by PyTorch MergeBot.
@@ -58,6 +58,10 @@ struct TORCH_API XPUHooksInterface {
  virtual Device getDeviceFromPtr(void* /*data*/) const {
    TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library.");
  }

  virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
    TORCH_CHECK(false, "Cannot synchronize XPU device without ATen_xpu library.");
  }
};

struct TORCH_API XPUHooksArgs {};
@@ -1,6 +1,7 @@
#pragma once

#include <c10/xpu/XPUFunctions.h>
#include <c10/xpu/XPUStream.h>

namespace at::xpu {
@@ -30,6 +30,12 @@ int XPUHooks::getNumGPUs() const {
  return at::xpu::device_count();
}

void XPUHooks::deviceSynchronize(DeviceIndex device_index) const {
  // Only the SYCL queues we have reserved will be synchronized, see Note
  // [Synchronize Streams on Device].
  c10::xpu::syncStreamsOnDevice(device_index);
}

REGISTER_XPU_HOOKS(XPUHooks);

} // namespace at::xpu::detail
@@ -16,6 +16,7 @@ struct XPUHooks : public at::XPUHooksInterface {
  int getGlobalIdxFromDevice(const at::Device& device) const override;
  Device getDeviceFromPtr(void* data) const override;
  int getNumGPUs() const override;
  void deviceSynchronize(DeviceIndex device_index) const override;
};

} // namespace at::xpu::detail
@@ -771,6 +771,7 @@ libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [

libtorch_python_xpu_sources = [
    "torch/csrc/xpu/Module.cpp",
    "torch/csrc/xpu/Stream.cpp",
]

libtorch_python_core_sources = [
@@ -7,7 +7,9 @@ torch.xpu
    :toctree: generated
    :nosignatures:

    StreamContext
    current_device
    current_stream
    device
    device_count
    device_of
@@ -17,4 +19,20 @@ torch.xpu
    init
    is_available
    is_initialized
    set_device
    set_stream
    stream
    synchronize

Streams
------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    Stream


.. This module needs to be documented. Adding here in the meantime
.. for tracking purposes
.. py:module:: torch.xpu.streams
@@ -68,6 +68,30 @@ if __name__ == "__main__":
        )
        self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.")

    def test_streams(self):
        s0 = torch.xpu.Stream()
        torch.xpu.set_stream(s0)
        s1 = torch.xpu.current_stream()
        self.assertEqual(s0, s1)
        s2 = torch.xpu.Stream()
        self.assertFalse(s0 == s2)
        torch.xpu.set_stream(s2)
        with torch.xpu.stream(s0):
            self.assertEqual(s0, torch.xpu.current_stream())
        self.assertEqual(s2, torch.xpu.current_stream())

    def test_stream_priority(self):
        low, high = torch.xpu.Stream.priority_range()
        s0 = torch.xpu.Stream(device=0, priority=low)

        self.assertEqual(low, s0.priority)
        self.assertEqual(torch.device("xpu:0"), s0.device)

        s1 = torch.xpu.Stream(device=0, priority=high)

        self.assertEqual(high, s1.priority)
        self.assertEqual(torch.device("xpu:0"), s1.device)


if __name__ == "__main__":
    run_tests()
@@ -1881,6 +1881,10 @@ def _xpu_maybeExchangeDevice(device: _int) -> _int: ...
def _xpu_getDevice() -> _int: ...
def _xpu_getDeviceCount() -> _int: ...
def _xpu_init() -> None: ...
def _xpu_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
def _xpu_getCurrentStream(device: _int) -> Tuple: ...
def _xpu_getCurrentRawStream(device: _int) -> _int: ...
def _xpu_synchronize(device: _int) -> None: ...

class _XpuDeviceProperties:
    name: str
@@ -1894,6 +1898,28 @@ class _XpuDeviceProperties:
    sub_group_sizes: List[_int]
    type: str

# Defined in torch/csrc/xpu/Stream.cpp
class _XpuStreamBase(Stream):
    stream_id: _int
    device_index: _int
    device_type: _int

    device: _device
    sycl_queue: _int
    priority: _int

    def __new__(
        cls,
        priority: _int = 0,
        stream_id: _int = 0,
        device_index: _int = 0,
        device_type: _int = 0,
    ) -> _XpuStreamBase: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    @staticmethod
    def priority_range() -> Tuple: ...

# Defined in torch/csrc/DataLoader.cpp
def _set_worker_signal_handlers(
    *arg: Any,
@@ -1443,6 +1443,7 @@ void initModule(PyObject* module);

#ifdef USE_XPU
PyMethodDef* THXPModule_methods();
void THXPStream_init(PyObject* module);
namespace torch::xpu {
void initModule(PyObject* module);
} // namespace torch::xpu
@@ -1586,6 +1587,10 @@ PyObject* initModule() {
  THCPGraph_init(module);
#endif

#ifdef USE_XPU
  THXPStream_init(module);
#endif

  auto set_module_attr =
      [&](const char* name, PyObject* v, bool incref = true) {
        // PyModule_AddObject steals reference
@@ -13,4 +13,7 @@ size_t TORCH_API device_count();
/// Returns true if at least one XPU device is available.
bool TORCH_API is_available();

/// Waits for all kernels in all streams on an XPU device to complete.
void TORCH_API synchronize(int64_t device_index);

} // namespace torch::xpu
@@ -11,4 +11,10 @@ bool is_available() {
  return xpu::device_count() > 0;
}

void synchronize(int64_t device_index) {
  TORCH_CHECK(is_available(), "No XPUs are available");
  at::detail::getXPUHooks().deviceSynchronize(
      static_cast<c10::DeviceIndex>(device_index));
}

} // namespace torch::xpu
@@ -5,6 +5,7 @@
#include <torch/csrc/Module.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
@@ -97,6 +98,89 @@ PyObject* THXPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) {
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_getCurrentStream_wrap(
    PyObject* self,
    PyObject* device_index) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(device_index), "invalid argument to current_stream");
  int64_t device = THPUtils_unpackLong(device_index);
  auto stream = at::xpu::getCurrentXPUStream(device);
  PyObject* output_tuple = PyTuple_New(3);
  PyTuple_SetItem(
      output_tuple, 0, THPUtils_packInt64(static_cast<int64_t>(stream.id())));
  PyTuple_SetItem(
      output_tuple,
      1,
      THPUtils_packInt64(static_cast<int64_t>(stream.device_index())));
  PyTuple_SetItem(
      output_tuple,
      2,
      THPUtils_packInt64(static_cast<int64_t>(stream.device_type())));
  return output_tuple;
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_getCurrentStream_raw(
    PyObject* self,
    PyObject* device_index) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(device_index),
      "invalid argument to getCurrentRawStream");
  int64_t device = THPUtils_unpackLong(device_index);
  return PyLong_FromVoidPtr(&at::xpu::getCurrentXPUStream(device).queue());
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_setStream_wrap(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  int64_t stream_id = 0;
  int64_t device_index = 0;
  int64_t device_type = 0;

  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  constexpr const char* kwlist[] = {
      "stream_id", "device_index", "device_type", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "|LLL",
          const_cast<char**>(kwlist),
          &stream_id,
          &device_index,
          &device_type)) {
    return nullptr;
  }

  auto stream = at::xpu::XPUStream::unpack3(
      stream_id, device_index, static_cast<c10::DeviceType>(device_type));

  auto device = c10::xpu::current_device();
  if (device != stream.device_index()) {
    c10::xpu::set_device(stream.device_index());
  }
  at::xpu::setCurrentXPUStream(stream);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_xpuSynchronize(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to synchronize");
  int device = THPUtils_unpackInt(arg);
  {
    pybind11::gil_scoped_release no_gil;
    // Only the SYCL queues we have reserved will be synchronized, see Note
    // [Synchronize Streams on Device].
    c10::xpu::syncStreamsOnDevice(static_cast<c10::DeviceIndex>(device));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// XPU module initialization

static void registerXpuDeviceProperties(PyObject* module) {
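To make the binding contract above concrete, a small sketch of the round-trip between these bindings (hypothetical usage, assuming XPU has been lazily initialized): `_xpu_getCurrentStream` returns the `(stream_id, device_index, device_type)` triple that `_xpu_setStream` consumes, and the C++ side reassembles the triple into an `XPUStream` via `XPUStream::unpack3`:

```python
import torch

# Query the current stream on device 0 as a packed triple.
stream_id, device_index, device_type = torch._C._xpu_getCurrentStream(0)

# Feed the same triple back; C++ rebuilds the stream with unpack3 and,
# if needed, switches the current device to match the stream's device.
torch._C._xpu_setStream(
    stream_id=stream_id,
    device_index=device_index,
    device_type=device_type,
)
```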
@@ -204,6 +288,19 @@ static struct PyMethodDef _THXPModule_methods[] = {
     METH_NOARGS,
     nullptr},
    {"_xpu_isInBadFork", THXPModule_isInBadFork_wrap, METH_NOARGS, nullptr},
    {"_xpu_getCurrentStream",
     THXPModule_getCurrentStream_wrap,
     METH_O,
     nullptr},
    {"_xpu_getCurrentRawStream",
     THXPModule_getCurrentStream_raw,
     METH_O,
     nullptr},
    {"_xpu_setStream",
     castPyCFunctionWithKeywords(THXPModule_setStream_wrap),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_xpu_synchronize", THXPModule_xpuSynchronize, METH_O, nullptr},
    {nullptr}};

PyMethodDef* THXPModule_methods() {
torch/csrc/xpu/Stream.cpp (new file, 201 lines)
@@ -0,0 +1,201 @@
#include <pybind11/pybind11.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/xpu/Module.h>
#include <torch/csrc/xpu/Stream.h>

#include <structmember.h>

PyObject* THXPStreamClass = nullptr;

static PyObject* THXPStream_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS

  const auto current_device = c10::xpu::current_device();

  int32_t priority = 0;
  int64_t stream_id = 0;
  int64_t device_index = 0;
  int64_t device_type = 0;

  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  constexpr const char* kwlist[] = {
      "priority", "stream_id", "device_index", "device_type", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "|iLLL",
          const_cast<char**>(kwlist),
          &priority,
          &stream_id,
          &device_index,
          &device_type)) {
    return nullptr;
  }

  THPObjectPtr ptr(type->tp_alloc(type, 0));
  if (!ptr) {
    return nullptr;
  }

  at::xpu::XPUStream stream = (stream_id || device_index || device_type)
      ? at::xpu::XPUStream::unpack3(
            stream_id, device_index, static_cast<c10::DeviceType>(device_type))
      : at::xpu::getStreamFromPool(priority, current_device);

  THXPStream* self = (THXPStream*)ptr.get();
  self->stream_id = static_cast<int64_t>(stream.id());
  self->device_index = static_cast<int64_t>(stream.device_index());
  self->device_type = static_cast<int64_t>(stream.device_type());
  new (&self->xpu_stream) at::xpu::XPUStream(stream);

  return (PyObject*)ptr.release();
  END_HANDLE_TH_ERRORS
}

static void THXPStream_dealloc(THXPStream* self) {
  self->xpu_stream.~XPUStream();
  Py_TYPE(self)->tp_free((PyObject*)self);
}

static PyObject* THXPStream_get_device(THXPStream* self, void* unused) {
  HANDLE_TH_ERRORS
  return THPDevice_New(self->xpu_stream.device());
  END_HANDLE_TH_ERRORS
}

static PyObject* THXPStream_get_sycl_queue(THXPStream* self, void* unused) {
  HANDLE_TH_ERRORS
  return PyLong_FromVoidPtr(&self->xpu_stream.queue());
  END_HANDLE_TH_ERRORS
}

static PyObject* THXPStream_get_priority(THXPStream* self, void* unused) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(self->xpu_stream.priority());
  END_HANDLE_TH_ERRORS
}

static PyObject* THXPStream_priority_range(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto [least_priority, greatest_priority] =
      at::xpu::XPUStream::priority_range();
  return Py_BuildValue("(ii)", least_priority, greatest_priority);
  END_HANDLE_TH_ERRORS
}

static PyObject* THXPStream_query(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto* self = (THXPStream*)_self;
  return PyBool_FromLong(self->xpu_stream.query());
  END_HANDLE_TH_ERRORS
}

static PyObject* THXPStream_synchronize(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS {
    pybind11::gil_scoped_release no_gil;
    auto* self = (THXPStream*)_self;
    self->xpu_stream.synchronize();
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THXPStream_eq(PyObject* _self, PyObject* _other) {
  HANDLE_TH_ERRORS
  auto* self = (THXPStream*)_self;
  auto* other = (THXPStream*)_other;
  return PyBool_FromLong(self->xpu_stream == other->xpu_stream);
  END_HANDLE_TH_ERRORS
}

// NOLINTNEXTLINE(modernize-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables,
// cppcoreguidelines-avoid-c-arrays)
static struct PyMemberDef THXPStream_members[] = {{nullptr}};

// NOLINTNEXTLINE(modernize-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables,
// cppcoreguidelines-avoid-c-arrays)
static struct PyGetSetDef THXPStream_properties[] = {
    {"sycl_queue",
     (getter)THXPStream_get_sycl_queue,
     nullptr,
     nullptr,
     nullptr},
    {"priority", (getter)THXPStream_get_priority, nullptr, nullptr, nullptr},
    {nullptr}};

// NOLINTNEXTLINE(modernize-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables,
// cppcoreguidelines-avoid-c-arrays)
static PyMethodDef THXPStream_methods[] = {
    {(char*)"query", THXPStream_query, METH_NOARGS, nullptr},
    {(char*)"synchronize", THXPStream_synchronize, METH_NOARGS, nullptr},
    {(char*)"priority_range",
     THXPStream_priority_range,
     METH_STATIC | METH_NOARGS,
     nullptr},
    {(char*)"__eq__", THXPStream_eq, METH_O, nullptr},
    {nullptr}};

PyTypeObject THXPStreamType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._XpuStreamBase", /* tp_name */
    sizeof(THXPStream), /* tp_basicsize */
    0, /* tp_itemsize */
    (destructor)THXPStream_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THXPStream_methods, /* tp_methods */
    THXPStream_members, /* tp_members */
    THXPStream_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THXPStream_pynew, /* tp_new */
};

void THXPStream_init(PyObject* module) {
  Py_INCREF(THPStreamClass);
  THXPStreamType.tp_base = THPStreamClass;
  THXPStreamClass = (PyObject*)&THXPStreamType;
  if (PyType_Ready(&THXPStreamType) < 0) {
    throw python_error();
  }
  Py_INCREF(&THXPStreamType);
  if (PyModule_AddObject(module, "_XpuStreamBase", (PyObject*)&THXPStreamType) <
      0) {
    throw python_error();
  }
}
torch/csrc/xpu/Stream.h (new file, 16 lines)
@@ -0,0 +1,16 @@
#pragma once

#include <c10/xpu/XPUStream.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/python_headers.h>

struct THXPStream : THPStream {
  at::xpu::XPUStream xpu_stream;
};
extern PyObject* THXPStreamClass;

void THXPStream_init(PyObject* module);

inline bool THXPStream_Check(PyObject* obj) {
  return THXPStreamClass && PyObject_IsInstance(obj, THXPStreamClass);
}
@@ -13,6 +13,7 @@ import torch
import torch._C
from .. import device as _device
from ._utils import _dummy_type, _get_device_index
from .streams import Stream

_initialized = False
_initialization_lock = threading.Lock()
@@ -228,17 +229,140 @@ def _get_device(device: Union[int, str, torch.device]) -> torch.device:
    return device


class StreamContext:
    r"""Context-manager that selects a given stream.

    All XPU kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        Stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: Streams are per-device.
    """

    cur_stream: Optional["torch.xpu.Stream"]

    def __init__(self, stream: Optional["torch.xpu.Stream"]):
        self.stream = stream
        self.idx = _get_device_index(None, True)
        if self.idx is None:
            self.idx = -1

    def __enter__(self):
        cur_stream = self.stream
        if cur_stream is None or self.idx == -1:
            return
        self.src_prev_stream = torch.xpu.current_stream(None)

        # If the stream is not on the current device, then set the current stream on the device
        if self.src_prev_stream.device != cur_stream.device:
            with device(cur_stream.device):
                self.dst_prev_stream = torch.xpu.current_stream(cur_stream.device)
        torch.xpu.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        cur_stream = self.stream
        if cur_stream is None or self.idx == -1:
            return

        # Reset the stream on the original device and destination device
        if self.src_prev_stream.device != cur_stream.device:
            torch.xpu.set_stream(self.dst_prev_stream)
        torch.xpu.set_stream(self.src_prev_stream)


def stream(stream: Optional["torch.xpu.Stream"]) -> StreamContext:
    r"""Wrap around the context-manager ``StreamContext`` that selects a given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's ``None``.
    """
    return StreamContext(stream)


def _set_stream_by_id(stream_id, device_index, device_type):
    r"""Set the stream specified by the stream id, device index, and device type.

    Args:
        stream_id (int): not visible to the user, used to assign the specific stream.
        device_index (int): selected device index.
        device_type (int): selected device type.
    """
    torch._C._xpu_setStream(
        stream_id=stream_id,
        device_index=device_index,
        device_type=device_type,
    )


def set_stream(stream: Stream):
    r"""Set the current stream. This is a wrapper API to set the stream.

    Usage of this function is discouraged in favor of the ``stream``
    context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    if stream is None:
        return
    _lazy_init()
    _set_stream_by_id(
        stream_id=stream.stream_id,
        device_index=stream.device_index,
        device_type=stream.device_type,
    )


def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    streamdata = torch._C._xpu_getCurrentStream(
        _get_device_index(device, optional=True)
    )
    return Stream(
        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
    )


def synchronize(device: _device_t = None) -> None:
    r"""Wait for all kernels in all streams on an XPU device to complete.

    Args:
        device (torch.device or int, optional): device for which to synchronize.
            It uses the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    _lazy_init()
    device = _get_device_index(device)
    return torch._C._xpu_synchronize(device)


__all__ = [
    "Stream",
    "StreamContext",
    "current_device",
    "current_stream",
    "device",
    "device_of",
    "device_count",
    "get_device_capability",
    "get_device_name",
    "get_device_properties",
    "get_stream",
    "init",
    "is_available",
    "is_bf16_supported",
    "is_initialized",
    "set_device",
    "set_stream",
    "stream",
    "streams",
    "synchronize",
]
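As a usage note on `StreamContext` above: it also handles a stream that lives on a device other than the current one, saving and restoring the current streams of both devices. A sketch under an assumed two-device setup (`xpu:0` current, `xpu:1` present):

```python
import torch

s1 = torch.xpu.Stream(device=1)  # stream on xpu:1 while xpu:0 is current

with torch.xpu.stream(s1):
    # Inside the context, s1 is the current stream on xpu:1.
    assert torch.xpu.current_stream(1) == s1

# On exit, the previous streams of both xpu:0 and xpu:1 are restored.
```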
torch/xpu/streams.py (new file, 70 lines)
@@ -0,0 +1,70 @@
import ctypes

import torch
from torch._streambase import _StreamBase
from ._utils import _dummy_type


if not hasattr(torch._C, "_XpuStreamBase"):
    # Define dummy base classes
    torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")


class Stream(torch._C._XpuStreamBase, _StreamBase):
    r"""Wrapper around an XPU stream.

    An XPU stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams.

    Args:
        device(torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream, should be 0 or
            negative, where negative numbers indicate higher priority. By default,
            streams have priority 0.
    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # setting device manager is expensive, so we avoid it unless necessary
        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
            return super().__new__(cls, priority=priority, **kwargs)
        else:
            with torch.xpu.device(device):
                return super().__new__(cls, priority=priority, **kwargs)

    # wait_event/wait_stream/record_event are stubs for now; they will be
    # implemented in the follow-up PR that adds Event support.
    def wait_event(self, event):
        pass

    def wait_stream(self, stream):
        pass

    def record_event(self, event=None):
        pass

    def query(self):
        r"""Check if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        return super().query()

    def synchronize(self):
        r"""Wait for all the kernels in this stream to complete."""
        super().synchronize()

    @property
    def _as_parameter_(self):
        return ctypes.c_void_p(self.sycl_queue)

    def __eq__(self, o):
        if isinstance(o, Stream):
            return super().__eq__(o)
        return False

    def __hash__(self):
        return hash((self.sycl_queue, self.device))

    def __repr__(self):
        return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"
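A brief sketch of the `Stream` surface defined above (assumes at least one XPU device; printed values are illustrative):

```python
import ctypes
import torch

low, high = torch.xpu.Stream.priority_range()
s = torch.xpu.Stream(device=0, priority=high)

print(s.priority)         # high
print(hex(s.sycl_queue))  # address of the underlying sycl::queue
print(s)                  # torch.xpu.Stream(device=xpu:0 sycl_queue=0x...)

# _as_parameter_ exposes the queue address as a void*, so a Stream can be
# passed directly to ctypes FFI calls.
ptr = ctypes.c_void_p(s.sycl_queue)
```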