mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Intel GPU] Support getStreamFromExternel for XPU. (#140268)
In AOT inductor scenario, the GPU Stream can be created outside of the pool of `XPUStream`, and we need to create a `XPUStream` which refers to this stream for the the common logic of AOTI, for example a stream guard is a guard for `XPUStream`. So we add the getStreamFromExternel following the design of CUDAStream. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140268 Approved by: https://github.com/desertfire, https://github.com/jansel, https://github.com/EikanWang
This commit is contained in:
committed by
PyTorch MergeBot
parent
843018f407
commit
3d227ae315
@ -37,21 +37,28 @@ thread_local std::unique_ptr<StreamId[]> current_streams = nullptr;
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// How do we assign stream IDs?
|
||||
//
|
||||
// -- 57 bits -- -- 5 bits ----- -- 3 bits --
|
||||
// zeros StreamIdIndex StreamIdType
|
||||
// -- 56 bits -- -- 5 bits ----- -- 3 bits -- -- 1 bits --
|
||||
// zeros StreamIdIndex StreamIdType Ext/native stream
|
||||
//
|
||||
// Where StreamIdType:
|
||||
// 000 = normal priority queue
|
||||
// 001 = high priority queue
|
||||
//
|
||||
// for external stream, StreamID is a sycl::queue* pointer
|
||||
// this means that last bit will always be 0
|
||||
// so when constructing StreamId for a native stream we set last bit to 1
|
||||
// to distinguish between native and external streams
|
||||
//
|
||||
// StreamId is 64-bit, so we can just rely on regular promotion rules.
|
||||
// We rely on StreamIdIndex and StreamIdType being non-negative;
|
||||
|
||||
using StreamIdIndex = uint8_t;
|
||||
enum class StreamIdType : uint8_t {
|
||||
// The higher the type number, the higher the priority.
|
||||
// EXT is used for external streams, which we don't know the priority of.
|
||||
NORMAL = 0x0,
|
||||
HIGH = 0X1,
|
||||
EXT = 0x7,
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& stream, StreamIdType q) {
|
||||
@ -60,6 +67,8 @@ inline std::ostream& operator<<(std::ostream& stream, StreamIdType q) {
|
||||
return stream << "NORMAL";
|
||||
case StreamIdType::HIGH:
|
||||
return stream << "HIGH";
|
||||
case StreamIdType::EXT:
|
||||
return stream << "EXT";
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -67,8 +76,13 @@ inline std::ostream& operator<<(std::ostream& stream, StreamIdType q) {
|
||||
}
|
||||
|
||||
inline StreamIdType streamIdType(StreamId s) {
|
||||
// Externally allocated streams have their id being the sycl:queue* pointer
|
||||
// so the last bit will be 0
|
||||
if ((!(s & 1) && s)) {
|
||||
return StreamIdType(StreamIdType::EXT);
|
||||
}
|
||||
int mask_for_type = (1 << kStreamTypeBits) - 1;
|
||||
auto st = static_cast<StreamIdType>(s & mask_for_type);
|
||||
auto st = static_cast<StreamIdType>((s >> 1) & mask_for_type);
|
||||
TORCH_CHECK(
|
||||
st == StreamIdType::NORMAL || st == StreamIdType::HIGH,
|
||||
"invalid StreamId: ",
|
||||
@ -78,12 +92,12 @@ inline StreamIdType streamIdType(StreamId s) {
|
||||
|
||||
inline StreamIdIndex streamIdIndex(StreamId s) {
|
||||
return static_cast<StreamIdIndex>(
|
||||
(s >> kStreamTypeBits) & ((1 << kStreamsPerPoolBits) - 1));
|
||||
(s >> (kStreamTypeBits + 1)) & ((1 << kStreamsPerPoolBits) - 1));
|
||||
}
|
||||
|
||||
inline StreamId makeStreamId(StreamIdType st, StreamIdIndex si) {
|
||||
return (static_cast<StreamId>(si) << kStreamTypeBits) |
|
||||
static_cast<StreamId>(st);
|
||||
return (static_cast<StreamId>(si) << (kStreamTypeBits + 1)) |
|
||||
(static_cast<StreamId>(st) << 1) | 1;
|
||||
}
|
||||
|
||||
void initGlobalStreamState() {
|
||||
@ -166,6 +180,14 @@ XPUStream XPUStreamForId(DeviceIndex device_index, StreamId stream_id) {
|
||||
int XPUStream::priority() const {
|
||||
StreamId stream_id = stream_.id();
|
||||
StreamIdType st = streamIdType(stream_id);
|
||||
// For an external queue which is not created in XPUStream, we can not trace
|
||||
// the priority. Workaround here since sycl doesn't support get priority from
|
||||
// a sycl::queue, like cudaStreamGetPriority .
|
||||
// TODO: remove this workaround when sycl supports get priority from a
|
||||
// sycl::queue.
|
||||
if (st == StreamIdType::EXT) {
|
||||
st = StreamIdType::NORMAL;
|
||||
}
|
||||
// StreamIdType and priority number are inversely related.
|
||||
return -static_cast<int>(st);
|
||||
}
|
||||
@ -177,6 +199,8 @@ sycl::queue& XPUStream::queue() const {
|
||||
StreamIdType st = streamIdType(stream_id);
|
||||
StreamIdIndex si = streamIdIndex(stream_id);
|
||||
switch (st) {
|
||||
case StreamIdType::EXT:
|
||||
return *(reinterpret_cast<sycl::queue*>(stream_id));
|
||||
case StreamIdType::NORMAL:
|
||||
case StreamIdType::HIGH:
|
||||
return *streams[device_index][static_cast<uint8_t>(st)][si];
|
||||
@ -221,6 +245,15 @@ XPUStream getStreamFromPool(const bool isHighPriority, DeviceIndex device) {
|
||||
return getStreamFromPool(priority, device);
|
||||
}
|
||||
|
||||
XPUStream getStreamFromExternal(
|
||||
sycl::queue* ext_stream,
|
||||
DeviceIndex device_index) {
|
||||
// The sycl::queue* will be the actual id
|
||||
|
||||
TORCH_CHECK(ext_stream, "External stream must not be a nullptr.");
|
||||
return XPUStreamForId(device_index, reinterpret_cast<int64_t>(ext_stream));
|
||||
}
|
||||
|
||||
// Note: The stream pools will be initialized if needed, at the first invocation
|
||||
// to this function.
|
||||
XPUStream getCurrentXPUStream(DeviceIndex device) {
|
||||
|
@ -157,6 +157,16 @@ getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
|
||||
C10_XPU_API XPUStream
|
||||
getStreamFromPool(const int priority, DeviceIndex device = -1);
|
||||
|
||||
/**
|
||||
* Get a XPUStream from a externally allocated one.
|
||||
*
|
||||
* This is mainly for interoperability with different libraries where we
|
||||
* want to operate on a non-torch allocated stream for data exchange or similar
|
||||
* purposes
|
||||
*/
|
||||
C10_API XPUStream
|
||||
getStreamFromExternal(sycl::queue* ext_stream, DeviceIndex device_index);
|
||||
|
||||
/**
|
||||
* Get the current XPU stream, for the passed XPU device, or for the current
|
||||
* device if no device index is passed.
|
||||
|
@ -1,8 +1,11 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/core/DeviceGuard.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/xpu/XPUException.h>
|
||||
#include <c10/xpu/XPUStream.h>
|
||||
#include <c10/xpu/test/impl/XPUTest.h>
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include <thread>
|
||||
@ -178,3 +181,70 @@ TEST(XPUStreamTest, StreamFunction) {
|
||||
validateHostData(hostData, numel);
|
||||
sycl::free(deviceData, c10::xpu::get_device_context());
|
||||
}
|
||||
|
||||
// Verifies external streams can be created and used
|
||||
TEST(XPUStreamTest, ExternalTest) {
|
||||
if (!has_xpu()) {
|
||||
return;
|
||||
}
|
||||
|
||||
c10::DeviceGuard device_guard(c10::Device(c10::DeviceType::XPU, 0));
|
||||
|
||||
using namespace sycl::ext::oneapi::property;
|
||||
sycl::queue* stream = new sycl::queue(
|
||||
c10::xpu::get_device_context(),
|
||||
c10::xpu::get_raw_device(0),
|
||||
c10::xpu::asyncHandler,
|
||||
{sycl::property::queue::in_order(), queue::priority_normal()});
|
||||
|
||||
at::xpu::XPUStream myStream = at::xpu::getStreamFromExternal(stream, 0);
|
||||
|
||||
at::xpu::setCurrentXPUStream(myStream);
|
||||
at::xpu::XPUStream curStream = at::xpu::getCurrentXPUStream();
|
||||
|
||||
ASSERT_TRUE(curStream == myStream);
|
||||
ASSERT_TRUE(&(curStream.queue()) == stream);
|
||||
|
||||
delete stream;
|
||||
}
|
||||
|
||||
// Verifies different external streams can be used for different devices at the
|
||||
// same time
|
||||
TEST(XPUStreamTest, ExternalMultiDeviceTest) {
|
||||
if (!has_xpu()) {
|
||||
return;
|
||||
}
|
||||
if (c10::xpu::device_count() < 2)
|
||||
return;
|
||||
sycl::queue* stream_0 = nullptr;
|
||||
sycl::queue* stream_1 = nullptr;
|
||||
|
||||
using namespace sycl::ext::oneapi::property;
|
||||
{
|
||||
c10::DeviceGuard device_guard(c10::Device(c10::DeviceType::XPU, 0));
|
||||
stream_0 = new sycl::queue(
|
||||
c10::xpu::get_device_context(),
|
||||
c10::xpu::get_raw_device(0),
|
||||
c10::xpu::asyncHandler,
|
||||
{sycl::property::queue::in_order(), queue::priority_normal()});
|
||||
}
|
||||
{
|
||||
c10::DeviceGuard device_guard(c10::Device(c10::DeviceType::XPU, 1));
|
||||
stream_0 = new sycl::queue(
|
||||
c10::xpu::get_device_context(),
|
||||
c10::xpu::get_raw_device(1),
|
||||
c10::xpu::asyncHandler,
|
||||
{sycl::property::queue::in_order(), queue::priority_normal()});
|
||||
}
|
||||
at::xpu::XPUStream myStream0 = at::xpu::getStreamFromExternal(stream_0, 0);
|
||||
at::xpu::XPUStream myStream1 = at::xpu::getStreamFromExternal(stream_1, 1);
|
||||
|
||||
at::xpu::setCurrentXPUStream(myStream0);
|
||||
ASSERT_TRUE(at::xpu::getCurrentXPUStream(0) == myStream0);
|
||||
at::xpu::setCurrentXPUStream(myStream1);
|
||||
ASSERT_TRUE(at::xpu::getCurrentXPUStream(0) == myStream0);
|
||||
ASSERT_TRUE(at::xpu::getCurrentXPUStream(1) == myStream1);
|
||||
|
||||
delete stream_0;
|
||||
delete stream_1;
|
||||
}
|
Reference in New Issue
Block a user