[Intel GPU] Support getStreamFromExternel for XPU. (#140268)

In AOT inductor scenario, the GPU Stream can be created outside of the pool of `XPUStream`, and we need to create a `XPUStream` which refers to this stream for the the common logic of AOTI, for example a stream guard is a guard for `XPUStream`.  So we add the getStreamFromExternel following the design of CUDAStream.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140268
Approved by: https://github.com/desertfire, https://github.com/jansel, https://github.com/EikanWang
This commit is contained in:
xinan.lin
2024-12-07 06:30:41 -08:00
committed by PyTorch MergeBot
parent 843018f407
commit 3d227ae315
3 changed files with 119 additions and 6 deletions

View File

@ -37,21 +37,28 @@ thread_local std::unique_ptr<StreamId[]> current_streams = nullptr;
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// How do we assign stream IDs?
//
// -- 57 bits -- -- 5 bits ----- -- 3 bits --
// zeros StreamIdIndex StreamIdType
// -- 56 bits -- -- 5 bits ----- -- 3 bits -- -- 1 bits --
// zeros StreamIdIndex StreamIdType Ext/native stream
//
// Where StreamIdType:
// 000 = normal priority queue
// 001 = high priority queue
//
// for external stream, StreamID is a sycl::queue* pointer
// this means that last bit will always be 0
// so when constructing StreamId for a native stream we set last bit to 1
// to distinguish between native and external streams
//
// StreamId is 64-bit, so we can just rely on regular promotion rules.
// We rely on StreamIdIndex and StreamIdType being non-negative;
using StreamIdIndex = uint8_t;
enum class StreamIdType : uint8_t {
// The higher the type number, the higher the priority.
// EXT is used for external streams, which we don't know the priority of.
NORMAL = 0x0,
HIGH = 0X1,
EXT = 0x7,
};
inline std::ostream& operator<<(std::ostream& stream, StreamIdType q) {
@ -60,6 +67,8 @@ inline std::ostream& operator<<(std::ostream& stream, StreamIdType q) {
return stream << "NORMAL";
case StreamIdType::HIGH:
return stream << "HIGH";
case StreamIdType::EXT:
return stream << "EXT";
default:
break;
}
@ -67,8 +76,13 @@ inline std::ostream& operator<<(std::ostream& stream, StreamIdType q) {
}
inline StreamIdType streamIdType(StreamId s) {
// Externally allocated streams have their id being the sycl:queue* pointer
// so the last bit will be 0
if ((!(s & 1) && s)) {
return StreamIdType(StreamIdType::EXT);
}
int mask_for_type = (1 << kStreamTypeBits) - 1;
auto st = static_cast<StreamIdType>(s & mask_for_type);
auto st = static_cast<StreamIdType>((s >> 1) & mask_for_type);
TORCH_CHECK(
st == StreamIdType::NORMAL || st == StreamIdType::HIGH,
"invalid StreamId: ",
@ -78,12 +92,12 @@ inline StreamIdType streamIdType(StreamId s) {
inline StreamIdIndex streamIdIndex(StreamId s) {
return static_cast<StreamIdIndex>(
(s >> kStreamTypeBits) & ((1 << kStreamsPerPoolBits) - 1));
(s >> (kStreamTypeBits + 1)) & ((1 << kStreamsPerPoolBits) - 1));
}
inline StreamId makeStreamId(StreamIdType st, StreamIdIndex si) {
return (static_cast<StreamId>(si) << kStreamTypeBits) |
static_cast<StreamId>(st);
return (static_cast<StreamId>(si) << (kStreamTypeBits + 1)) |
(static_cast<StreamId>(st) << 1) | 1;
}
void initGlobalStreamState() {
@ -166,6 +180,14 @@ XPUStream XPUStreamForId(DeviceIndex device_index, StreamId stream_id) {
int XPUStream::priority() const {
StreamId stream_id = stream_.id();
StreamIdType st = streamIdType(stream_id);
// For an external queue which is not created in XPUStream, we can not trace
// the priority. Workaround here since sycl doesn't support get priority from
// a sycl::queue, like cudaStreamGetPriority .
// TODO: remove this workaround when sycl supports get priority from a
// sycl::queue.
if (st == StreamIdType::EXT) {
st = StreamIdType::NORMAL;
}
// StreamIdType and priority number are inversely related.
return -static_cast<int>(st);
}
@ -177,6 +199,8 @@ sycl::queue& XPUStream::queue() const {
StreamIdType st = streamIdType(stream_id);
StreamIdIndex si = streamIdIndex(stream_id);
switch (st) {
case StreamIdType::EXT:
return *(reinterpret_cast<sycl::queue*>(stream_id));
case StreamIdType::NORMAL:
case StreamIdType::HIGH:
return *streams[device_index][static_cast<uint8_t>(st)][si];
@ -221,6 +245,15 @@ XPUStream getStreamFromPool(const bool isHighPriority, DeviceIndex device) {
return getStreamFromPool(priority, device);
}
XPUStream getStreamFromExternal(
sycl::queue* ext_stream,
DeviceIndex device_index) {
// The sycl::queue* will be the actual id
TORCH_CHECK(ext_stream, "External stream must not be a nullptr.");
return XPUStreamForId(device_index, reinterpret_cast<int64_t>(ext_stream));
}
// Note: The stream pools will be initialized if needed, at the first invocation
// to this function.
XPUStream getCurrentXPUStream(DeviceIndex device) {

View File

@ -157,6 +157,16 @@ getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
C10_XPU_API XPUStream
getStreamFromPool(const int priority, DeviceIndex device = -1);
/**
* Get a XPUStream from a externally allocated one.
*
* This is mainly for interoperability with different libraries where we
* want to operate on a non-torch allocated stream for data exchange or similar
* purposes
*/
C10_API XPUStream
getStreamFromExternal(sycl::queue* ext_stream, DeviceIndex device_index);
/**
* Get the current XPU stream, for the passed XPU device, or for the current
* device if no device index is passed.

View File

@ -1,8 +1,11 @@
#include <gtest/gtest.h>
#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>
#include <c10/xpu/XPUException.h>
#include <c10/xpu/XPUStream.h>
#include <c10/xpu/test/impl/XPUTest.h>
#include <optional>
#include <thread>
@ -178,3 +181,70 @@ TEST(XPUStreamTest, StreamFunction) {
validateHostData(hostData, numel);
sycl::free(deviceData, c10::xpu::get_device_context());
}
// Verifies external streams can be created and used
TEST(XPUStreamTest, ExternalTest) {
if (!has_xpu()) {
return;
}
c10::DeviceGuard device_guard(c10::Device(c10::DeviceType::XPU, 0));
using namespace sycl::ext::oneapi::property;
sycl::queue* stream = new sycl::queue(
c10::xpu::get_device_context(),
c10::xpu::get_raw_device(0),
c10::xpu::asyncHandler,
{sycl::property::queue::in_order(), queue::priority_normal()});
at::xpu::XPUStream myStream = at::xpu::getStreamFromExternal(stream, 0);
at::xpu::setCurrentXPUStream(myStream);
at::xpu::XPUStream curStream = at::xpu::getCurrentXPUStream();
ASSERT_TRUE(curStream == myStream);
ASSERT_TRUE(&(curStream.queue()) == stream);
delete stream;
}
// Verifies different external streams can be used for different devices at the
// same time
TEST(XPUStreamTest, ExternalMultiDeviceTest) {
if (!has_xpu()) {
return;
}
if (c10::xpu::device_count() < 2)
return;
sycl::queue* stream_0 = nullptr;
sycl::queue* stream_1 = nullptr;
using namespace sycl::ext::oneapi::property;
{
c10::DeviceGuard device_guard(c10::Device(c10::DeviceType::XPU, 0));
stream_0 = new sycl::queue(
c10::xpu::get_device_context(),
c10::xpu::get_raw_device(0),
c10::xpu::asyncHandler,
{sycl::property::queue::in_order(), queue::priority_normal()});
}
{
c10::DeviceGuard device_guard(c10::Device(c10::DeviceType::XPU, 1));
stream_0 = new sycl::queue(
c10::xpu::get_device_context(),
c10::xpu::get_raw_device(1),
c10::xpu::asyncHandler,
{sycl::property::queue::in_order(), queue::priority_normal()});
}
at::xpu::XPUStream myStream0 = at::xpu::getStreamFromExternal(stream_0, 0);
at::xpu::XPUStream myStream1 = at::xpu::getStreamFromExternal(stream_1, 1);
at::xpu::setCurrentXPUStream(myStream0);
ASSERT_TRUE(at::xpu::getCurrentXPUStream(0) == myStream0);
at::xpu::setCurrentXPUStream(myStream1);
ASSERT_TRUE(at::xpu::getCurrentXPUStream(0) == myStream0);
ASSERT_TRUE(at::xpu::getCurrentXPUStream(1) == myStream1);
delete stream_0;
delete stream_1;
}