diff --git a/aten/src/ATen/test/cuda_stream_test.cpp b/aten/src/ATen/test/cuda_stream_test.cpp index a1944ca4422e..dfd823f27349 100644 --- a/aten/src/ATen/test/cuda_stream_test.cpp +++ b/aten/src/ATen/test/cuda_stream_test.cpp @@ -69,6 +69,24 @@ TEST(TestStream, CopyAndMoveTest) { ASSERT_EQ_CUDA(moveStream.stream(), cuda_stream); } +// Verifies stream priority is handled properly +TEST(TestStream, StreamPriorityTest) { + if (!at::cuda::is_available()) return; + auto [least_priority, greatest_priority] = + at::cuda::CUDAStream::priority_range(); + EXPECT_EQ(least_priority, 0); + + auto stream = at::cuda::getStreamFromPool(-1); + EXPECT_EQ(stream.priority(), -1); + EXPECT_GT(10, at::cuda::max_compile_time_stream_priorities); + stream = at::cuda::getStreamFromPool(-10); + EXPECT_EQ(stream.priority(), greatest_priority); + stream = at::cuda::getStreamFromPool(0); + EXPECT_EQ(stream.priority(), 0); + stream = at::cuda::getStreamFromPool(10); + EXPECT_EQ(stream.priority(), 0); +} + // Verifies streams are set properly TEST(TestStream, GetAndSetTest) { if (!at::cuda::is_available()) return; diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index 23ab4a7f6edc..bbaeeba84ddc 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -318,10 +318,6 @@ CUDAStream getStreamFromPool(const int priority, DeviceIndex device_index) { device_index = current_device(); c10::cuda::SetTargetDevice(); } - TORCH_CHECK( - priority <= 0, - "Expected cuda stream priority to be less than or equal to 0, got ", - priority); check_gpu(device_index); #if !defined(USE_ROCM) // See Note [HIP Lazy Streams] @@ -329,9 +325,7 @@ CUDAStream getStreamFromPool(const int priority, DeviceIndex device_index) { c10::call_once( device_flags[device_index], initDeviceStreamState, device_index); #endif - auto pri_idx = -priority; - pri_idx = - std::min(pri_idx, max_stream_priorities - 1); // pri_idx is zero-based + auto pri_idx = std::clamp(-priority, 0, max_stream_priorities - 1); const auto idx = get_idx(priority_counters[pri_idx][device_index]); StreamIdType id_type = StreamIdType(pri_idx + 1); return CUDAStreamForId(device_index, makeStreamId(id_type, idx)); diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index 6ef0baeeaf4e..cefe9caf9cda 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -22,9 +22,11 @@ class Stream(torch._C._CudaStreamBase): device(torch.device or int, optional): a device on which to allocate the stream. If :attr:`device` is ``None`` (default) or a negative integer, this will use the current device. - priority(int, optional): priority of the stream, should be 0 or - negative, where negative numbers indicate higher priority. By default, - streams have priority 0. + priority(int, optional): priority of the stream, which can be positive, 0, or negative. + A lower number indicates a higher priority. By default, the priority is set to 0. + If the value falls outside of the allowed priority range, it will automatically be + mapped to the nearest valid priority (lowest for large positive numbers or + highest for large negative numbers). """