mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add low priority XPU Stream (#141119)
# Motivation Due to the potential for the external SYCL queue to have a low priority, we need to support the low-priority SYCL queue for native XPU Streams to maintain consistency. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141119 Approved by: https://github.com/gujinghui, https://github.com/albanD ghstack dependencies: #142347
This commit is contained in:
committed by
PyTorch MergeBot
parent
39450ae655
commit
a68c0ca497
@ -33,32 +33,51 @@ std::deque<
|
||||
|
||||
thread_local std::unique_ptr<StreamId[]> current_streams = nullptr;
|
||||
|
||||
// Note [StreamId assignment]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// How do we assign stream IDs?
|
||||
//
|
||||
// -- 55 bits -- -- 5 bits -- -- 3 bits -- -- 1 bit --
|
||||
// zeros StreamIdIndex StreamIdType Ext/native stream
|
||||
// ignored for ext ignored for ext
|
||||
//
|
||||
// Where StreamIdType:
|
||||
// 000 = normal priority queue
|
||||
// 001 = high priority queue
|
||||
// 111 = external queue
|
||||
//
|
||||
// For external stream, StreamID is a sycl::queue* pointer. This means that last
|
||||
// bit will always be 0. So when constructing StreamId for a native stream we
|
||||
// set last bit to 1 to distinguish between native and external streams. For
|
||||
// more details, see Note [External XPU Stream].
|
||||
//
|
||||
// StreamId is 64-bit, so we can just rely on regular promotion rules.
|
||||
// We rely on StreamIdIndex and StreamIdType being non-negative;
|
||||
/*
|
||||
* Note [StreamId assignment]
|
||||
* ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
* How do we assign stream IDs?
|
||||
*
|
||||
* -- 55 bits -- -- 5 bits -- -- 3 bits -- -- 1 bit --
|
||||
* zeros StreamIdIndex StreamIdType Ext/native stream
|
||||
* ignored for ext ignored for ext
|
||||
*
|
||||
* Where StreamIdType:
|
||||
* 000 = low priority queue
|
||||
* 001 = normal priority queue
|
||||
* 010 = high priority queue
|
||||
* 111 = external queue
|
||||
*
|
||||
* For external stream, StreamID is a sycl::queue* pointer. This means that last
|
||||
* bit will always be 0. So when constructing StreamId for a native stream we
|
||||
* set last bit to 1 to distinguish between native and external streams. For
|
||||
* more details, see Note [External XPU Stream].
|
||||
*
|
||||
* StreamId is 64-bit, so we can just rely on regular promotion rules.
|
||||
* We rely on StreamIdIndex and StreamIdType being non-negative;
|
||||
*/
|
||||
|
||||
/*
|
||||
* Note [XPU Stream priorities]
|
||||
* XPU stream priority levels are defined based on the following design
|
||||
* principles:
|
||||
* 1. Higher priority number indicates lower priority.
|
||||
* 2. The default priority, `normal`, corresponds to a priority number of 0.
|
||||
* 3. StreamIdType and priority number are inversely related.
|
||||
*
|
||||
* This relationship can be summarized as follows:
|
||||
* -- priority type -- -- priority number -- -- type number --
|
||||
* low 1 0
|
||||
* normal 0 1
|
||||
* high -1 2
|
||||
*/
|
||||
|
||||
using StreamIdIndex = uint8_t;
|
||||
enum class StreamIdType : uint8_t {
|
||||
// The higher the type number, the higher the priority for the native stream.
|
||||
NORMAL = 0x0,
|
||||
HIGH = 0X1,
|
||||
LOW = 0x0,
|
||||
NORMAL = 0x1,
|
||||
HIGH = 0x2,
|
||||
// For an external stream, the last bit of StreamId is 0, whose priority is
|
||||
// queried at runtime.
|
||||
EXT = 0x7,
|
||||
@ -66,6 +85,8 @@ enum class StreamIdType : uint8_t {
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& stream, StreamIdType q) {
|
||||
switch (q) {
|
||||
case StreamIdType::LOW:
|
||||
return stream << "LOW";
|
||||
case StreamIdType::NORMAL:
|
||||
return stream << "NORMAL";
|
||||
case StreamIdType::HIGH:
|
||||
@ -87,7 +108,8 @@ inline StreamIdType streamIdType(StreamId s) {
|
||||
int mask_for_type = (1 << kStreamTypeBits) - 1;
|
||||
auto st = static_cast<StreamIdType>((s >> 1) & mask_for_type);
|
||||
TORCH_CHECK(
|
||||
st == StreamIdType::NORMAL || st == StreamIdType::HIGH,
|
||||
st == StreamIdType::NORMAL || st == StreamIdType::HIGH ||
|
||||
st == StreamIdType::LOW,
|
||||
"invalid StreamId: ",
|
||||
s);
|
||||
return st;
|
||||
@ -116,8 +138,12 @@ void initDeviceStreamState(DeviceIndex device) {
|
||||
using namespace sycl::ext::oneapi::property;
|
||||
// Need to align with StreamIdType.
|
||||
const std::vector<sycl::property_list> properties = {
|
||||
{sycl::property::queue::in_order(), queue::priority_low()},
|
||||
{sycl::property::queue::in_order(), queue::priority_normal()},
|
||||
{sycl::property::queue::in_order(), queue::priority_high()}};
|
||||
TORCH_CHECK(
|
||||
properties.size() == max_compile_time_stream_priorities,
|
||||
"The number of stream priorities should be equal to max_compile_time_stream_priorities");
|
||||
for (const auto p : c10::irange(max_compile_time_stream_priorities)) {
|
||||
for (const auto i : c10::irange(kStreamsPerPool)) {
|
||||
auto& stream = streams[device][p][i];
|
||||
@ -186,16 +212,19 @@ int XPUStream::priority() const {
|
||||
if (C10_UNLIKELY(st == StreamIdType::EXT)) {
|
||||
// Query external stream priority
|
||||
using namespace sycl::ext::oneapi::property;
|
||||
// Default priority for SYCL queue is normal.
|
||||
st = StreamIdType::NORMAL;
|
||||
if (queue().has_property<queue::priority_normal>()) {
|
||||
st = StreamIdType::NORMAL;
|
||||
} else if (queue().has_property<queue::priority_high>()) {
|
||||
st = StreamIdType::HIGH;
|
||||
} else if (queue().has_property<queue::priority_low>()) {
|
||||
st = StreamIdType::LOW;
|
||||
} else {
|
||||
// Default priority for SYCL queue is normal.
|
||||
st = StreamIdType::NORMAL;
|
||||
}
|
||||
}
|
||||
// StreamIdType and priority number are inversely related.
|
||||
return -static_cast<int>(st);
|
||||
// See Note [XPU Stream priorities]
|
||||
return -static_cast<int>(st) + 1;
|
||||
}
|
||||
|
||||
// See Note [StreamId assignment]
|
||||
@ -232,14 +261,11 @@ XPUStream getStreamFromPool(const int priority, DeviceIndex device) {
|
||||
device = c10::xpu::current_device();
|
||||
}
|
||||
check_device_index(device);
|
||||
TORCH_CHECK(
|
||||
priority <= 0,
|
||||
"Expected XPU stream priority to be less than or equal to 0, got ",
|
||||
priority);
|
||||
// Initializes the stream pools (once)
|
||||
initDeviceStreamOnce(device);
|
||||
// See Note [XPU Stream priorities]
|
||||
auto priority_idx =
|
||||
std::min(-priority, max_compile_time_stream_priorities - 1);
|
||||
std::clamp(-priority + 1, 0, max_compile_time_stream_priorities - 1);
|
||||
const auto idx = get_idx(priority_counters[device][priority_idx]);
|
||||
auto id_type = static_cast<StreamIdType>(priority_idx);
|
||||
return XPUStreamForId(device, makeStreamId(id_type, idx));
|
||||
@ -248,7 +274,8 @@ XPUStream getStreamFromPool(const int priority, DeviceIndex device) {
|
||||
XPUStream getStreamFromPool(const bool isHighPriority, DeviceIndex device) {
|
||||
initXPUStreamsOnce();
|
||||
// If isHighPriority is true, return the stream with the highest priority.
|
||||
int priority = isHighPriority ? -max_compile_time_stream_priorities + 1 : 0;
|
||||
// See Note [XPU Stream priorities]
|
||||
int priority = isHighPriority ? -max_compile_time_stream_priorities + 2 : 0;
|
||||
return getStreamFromPool(priority, device);
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,7 @@ namespace c10::xpu {
|
||||
* threads as the SYCL specification described.
|
||||
*/
|
||||
|
||||
static constexpr int max_compile_time_stream_priorities = 2;
|
||||
static constexpr int max_compile_time_stream_priorities = 3;
|
||||
|
||||
/*
|
||||
* This serves as a wrapper around c10::Stream and acts as a representation for
|
||||
@ -132,7 +132,8 @@ class C10_XPU_API XPUStream {
|
||||
|
||||
/// Return the range of priority **supported by PyTorch**.
|
||||
static std::tuple<int, int> priority_range() {
|
||||
return std::make_tuple(0, -max_compile_time_stream_priorities + 1);
|
||||
// See Note [XPU Stream priorities]
|
||||
return std::make_tuple(1, -max_compile_time_stream_priorities + 2);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -69,11 +69,24 @@ TEST(XPUStreamTest, StreamBehavior) {
|
||||
|
||||
auto [least_priority, greatest_priority] =
|
||||
c10::xpu::XPUStream::priority_range();
|
||||
EXPECT_EQ(least_priority, 0);
|
||||
EXPECT_TRUE(greatest_priority < 0);
|
||||
EXPECT_EQ(least_priority, 1);
|
||||
EXPECT_EQ(greatest_priority, -1);
|
||||
|
||||
stream = c10::xpu::getStreamFromPool(/* isHighPriority */ true);
|
||||
EXPECT_TRUE(stream.priority() < 0);
|
||||
EXPECT_EQ(stream.priority(), -1);
|
||||
stream = c10::xpu::getStreamFromPool(/* isHighPriority */ false);
|
||||
EXPECT_EQ(stream.priority(), 0);
|
||||
|
||||
stream = c10::xpu::getStreamFromPool(-1);
|
||||
EXPECT_EQ(stream.priority(), -1);
|
||||
stream = c10::xpu::getStreamFromPool(-10);
|
||||
EXPECT_EQ(stream.priority(), -1);
|
||||
stream = c10::xpu::getStreamFromPool(0);
|
||||
EXPECT_EQ(stream.priority(), 0);
|
||||
stream = c10::xpu::getStreamFromPool(1);
|
||||
EXPECT_EQ(stream.priority(), 1);
|
||||
stream = c10::xpu::getStreamFromPool(10);
|
||||
EXPECT_EQ(stream.priority(), 1);
|
||||
|
||||
if (c10::xpu::device_count() <= 1) {
|
||||
return;
|
||||
|
@ -21,9 +21,11 @@ class Stream(torch._C._XpuStreamBase):
|
||||
device(torch.device or int, optional): a device on which to allocate
|
||||
the stream. If :attr:`device` is ``None`` (default) or a negative
|
||||
integer, this will use the current device.
|
||||
priority(int, optional): priority of the stream, should be 0 or
|
||||
negative, where negative numbers indicate higher priority. By default,
|
||||
streams have priority 0.
|
||||
priority(int, optional): priority of the stream, which can be positive, 0, or negative.
|
||||
A lower number indicates a higher priority. By default, the priority is set to 0.
|
||||
If the value falls outside of the allowed priority range, it will automatically be
|
||||
mapped to the nearest valid priority (lowest for large positive numbers or
|
||||
highest for large negative numbers).
|
||||
"""
|
||||
|
||||
def __new__(cls, device=None, priority=0, **kwargs):
|
||||
|
Reference in New Issue
Block a user