mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refine CUDA Stream priority (#143849)
# Motivation
As mentioned in https://github.com/pytorch/pytorch/pull/141119#discussion_r1897480515, this change makes `getStreamFromPool` properly handle a priority value that lies outside of the supported priority range, instead of rejecting positive values with an error.

# Additional Context
If the value falls outside of the allowed priority range, it will automatically be mapped to the nearest valid priority (either lowest or highest).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143849
Approved by: https://github.com/albanD, https://github.com/EikanWang
ghstack dependencies: #142347, #141119, #141123, #143799
This commit is contained in:
committed by
PyTorch MergeBot
parent
3848de55ed
commit
09e47ab7ab
@ -69,6 +69,24 @@ TEST(TestStream, CopyAndMoveTest) {
|
|||||||
ASSERT_EQ_CUDA(moveStream.stream(), cuda_stream);
|
ASSERT_EQ_CUDA(moveStream.stream(), cuda_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies stream priority is handled properly
|
||||||
|
TEST(TestStream, StreamPriorityTest) {
|
||||||
|
if (!at::cuda::is_available()) return;
|
||||||
|
auto [least_priority, greatest_priority] =
|
||||||
|
at::cuda::CUDAStream::priority_range();
|
||||||
|
EXPECT_EQ(least_priority, 0);
|
||||||
|
|
||||||
|
auto stream = at::cuda::getStreamFromPool(-1);
|
||||||
|
EXPECT_EQ(stream.priority(), -1);
|
||||||
|
EXPECT_GT(10, at::cuda::max_compile_time_stream_priorities);
|
||||||
|
stream = at::cuda::getStreamFromPool(-10);
|
||||||
|
EXPECT_EQ(stream.priority(), greatest_priority);
|
||||||
|
stream = at::cuda::getStreamFromPool(0);
|
||||||
|
EXPECT_EQ(stream.priority(), 0);
|
||||||
|
stream = at::cuda::getStreamFromPool(10);
|
||||||
|
EXPECT_EQ(stream.priority(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
// Verifies streams are set properly
|
// Verifies streams are set properly
|
||||||
TEST(TestStream, GetAndSetTest) {
|
TEST(TestStream, GetAndSetTest) {
|
||||||
if (!at::cuda::is_available()) return;
|
if (!at::cuda::is_available()) return;
|
||||||
|
@ -318,10 +318,6 @@ CUDAStream getStreamFromPool(const int priority, DeviceIndex device_index) {
|
|||||||
device_index = current_device();
|
device_index = current_device();
|
||||||
c10::cuda::SetTargetDevice();
|
c10::cuda::SetTargetDevice();
|
||||||
}
|
}
|
||||||
TORCH_CHECK(
|
|
||||||
priority <= 0,
|
|
||||||
"Expected cuda stream priority to be less than or equal to 0, got ",
|
|
||||||
priority);
|
|
||||||
check_gpu(device_index);
|
check_gpu(device_index);
|
||||||
#if !defined(USE_ROCM)
|
#if !defined(USE_ROCM)
|
||||||
// See Note [HIP Lazy Streams]
|
// See Note [HIP Lazy Streams]
|
||||||
@ -329,9 +325,7 @@ CUDAStream getStreamFromPool(const int priority, DeviceIndex device_index) {
|
|||||||
c10::call_once(
|
c10::call_once(
|
||||||
device_flags[device_index], initDeviceStreamState, device_index);
|
device_flags[device_index], initDeviceStreamState, device_index);
|
||||||
#endif
|
#endif
|
||||||
auto pri_idx = -priority;
|
auto pri_idx = std::clamp(-priority, 0, max_stream_priorities - 1);
|
||||||
pri_idx =
|
|
||||||
std::min(pri_idx, max_stream_priorities - 1); // pri_idx is zero-based
|
|
||||||
const auto idx = get_idx(priority_counters[pri_idx][device_index]);
|
const auto idx = get_idx(priority_counters[pri_idx][device_index]);
|
||||||
StreamIdType id_type = StreamIdType(pri_idx + 1);
|
StreamIdType id_type = StreamIdType(pri_idx + 1);
|
||||||
return CUDAStreamForId(device_index, makeStreamId(id_type, idx));
|
return CUDAStreamForId(device_index, makeStreamId(id_type, idx));
|
||||||
|
@ -22,9 +22,11 @@ class Stream(torch._C._CudaStreamBase):
|
|||||||
device(torch.device or int, optional): a device on which to allocate
|
device(torch.device or int, optional): a device on which to allocate
|
||||||
the stream. If :attr:`device` is ``None`` (default) or a negative
|
the stream. If :attr:`device` is ``None`` (default) or a negative
|
||||||
integer, this will use the current device.
|
integer, this will use the current device.
|
||||||
priority(int, optional): priority of the stream, should be 0 or
|
priority(int, optional): priority of the stream, which can be positive, 0, or negative.
|
||||||
negative, where negative numbers indicate higher priority. By default,
|
A lower number indicates a higher priority. By default, the priority is set to 0.
|
||||||
streams have priority 0.
|
If the value falls outside of the allowed priority range, it will automatically be
|
||||||
|
mapped to the nearest valid priority (lowest for large positive numbers or
|
||||||
|
highest for large negative numbers).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user