Files
pytorch/test/cpp/nativert/test_mpmc_queue.cpp
Zhengxu Chen 5bb154e6fd [nativert] Move MPMCQueue to torch/nativert. (#152837)
Summary:
Torch Native Runtime RFC: https://github.com/zhxchen17/rfcs/blob/master/RFC-0043-torch-native-runtime.md

To land the runtime into PyTorch core, we will gradually land logical parts of the code into the Github issue and get each piece properly reviewed.

This diff adds a small library implementing a multi-producer multi-consumer queue which will be used to synchronize tasks for Torch Native Runtime.

Differential Revision: D74184245

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152837
Approved by: https://github.com/albanD, https://github.com/dolpm, https://github.com/swolchok
2025-05-07 21:17:42 +00:00

122 lines
3.0 KiB
C++

#include <torch/nativert/detail/MPMCQueue.h>

#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

#include <gtest/gtest.h>
using torch::nativert::detail::MPMCQueue;
// A freshly constructed queue contains no elements, so the very first
// read attempt must report failure and leave the output untouched.
TEST(MPMCQueueTest, EmptyQueue) {
  MPMCQueue<int> queue(5);
  int value = 0;
  EXPECT_FALSE(queue.readIfNotEmpty(value));
}
// Round-trip a single element: a successful write followed by a
// successful read must yield the same value that was written.
TEST(MPMCQueueTest, SingleElement) {
  MPMCQueue<int> queue(5);
  EXPECT_TRUE(queue.writeIfNotFull(10));

  int value = 0;
  EXPECT_TRUE(queue.readIfNotEmpty(value));
  EXPECT_EQ(value, 10);
}
// Fill the queue to capacity, then drain it completely; elements must
// come back in FIFO order.
TEST(MPMCQueueTest, MultipleElements) {
  constexpr int kCapacity = 5;
  MPMCQueue<int> queue(kCapacity);

  for (int value = 0; value < kCapacity; ++value) {
    EXPECT_TRUE(queue.writeIfNotFull(value));
  }

  for (int expected = 0; expected < kCapacity; ++expected) {
    int value = 0;
    EXPECT_TRUE(queue.readIfNotEmpty(value));
    EXPECT_EQ(value, expected);
  }
}
// Once the queue holds `capacity` elements, a further write must be
// rejected rather than blocking or overwriting.
TEST(MPMCQueueTest, FullQueue) {
  constexpr int kCapacity = 5;
  MPMCQueue<int> queue(kCapacity);

  for (int value = 0; value < kCapacity; ++value) {
    EXPECT_TRUE(queue.writeIfNotFull(value));
  }

  // Queue is now full; the next write should fail.
  EXPECT_FALSE(queue.writeIfNotFull(10));
}
// One writer and one reader running concurrently. The queue capacity (10)
// exceeds the number of writes (5), so every writeIfNotFull call is
// guaranteed to succeed -- the original test silently ignored the return
// value, which would have masked a failed write (and left the reader
// spinning forever). Each value read must lie in the written range [0, 5).
TEST(MPMCQueueTest, ConcurrentAccess) {
  MPMCQueue<int> queue(10);
  std::thread writer([&queue]() {
    for (int i = 0; i < 5; ++i) {
      // Capacity (10) > total writes (5), so this can never fail.
      EXPECT_TRUE(queue.writeIfNotFull(i));
    }
  });
  std::thread reader([&queue]() {
    for (int i = 0; i < 5; ++i) {
      int out = 0;
      while (!queue.readIfNotEmpty(out)) {
        // Busy-wait until an element is available.
      }
      EXPECT_GE(out, 0);
      EXPECT_LT(out, 5);
    }
  });
  writer.join();
  reader.join();
}
// Many writers and many readers running concurrently. Each writer pushes a
// disjoint range of values; readers drain the queue until the expected
// total number of elements has been observed.
//
// Bug fix: the writer lambdas previously captured the loop variable `i`
// by reference via [&]. By the time a writer thread actually ran, `i`
// could have been incremented or destroyed by the spawning loop, which is
// a data race / dangling reference (undefined behavior) and makes the
// produced values indeterminate. Capturing `i` by value ([&, i]) pins the
// writer's index at thread-creation time.
TEST(MPMCQueueTest, MPMCConcurrentAccess) {
  const size_t queueCapacity = 100000;
  const size_t numWriters = 5;
  const size_t numReaders = 5;
  const size_t numElementsPerWriter = 10000;
  MPMCQueue<int> queue(queueCapacity);

  // Writer threads: writer i produces [i*N, (i+1)*N).
  std::vector<std::thread> writers;
  writers.reserve(numWriters);
  for (size_t i = 0; i < numWriters; ++i) {
    // Capture `i` by value: the loop variable mutates (and dies) while
    // the threads are still running.
    writers.emplace_back([&, i]() {
      for (size_t j = 0; j < numElementsPerWriter; ++j) {
        size_t value = i * numElementsPerWriter + j;
        while (!queue.writeIfNotFull(static_cast<int>(value))) {
          // Retry until the queue has space.
        }
      }
    });
  }

  // Reader threads: spin until the global count of consumed elements
  // reaches the number produced. Every successful read increments the
  // shared atomic counter, so readers collectively drain the queue.
  std::vector<std::thread> readers;
  std::atomic<size_t> totalReadCount{0};
  readers.reserve(numReaders);
  for (size_t i = 0; i < numReaders; ++i) {
    readers.emplace_back([&]() {
      int value = 0;
      while (totalReadCount < numWriters * numElementsPerWriter) {
        if (queue.readIfNotEmpty(value)) {
          ++totalReadCount;
        }
      }
    });
  }

  // Join all threads.
  for (auto& writer : writers) {
    writer.join();
  }
  for (auto& reader : readers) {
    reader.join();
  }

  // Verify that every produced element was consumed exactly once.
  EXPECT_EQ(totalReadCount, numWriters * numElementsPerWriter);
}
// The queue must work with types that are movable but not copyable
// (e.g. unique_ptr-like handles): writing consumes an rvalue and
// reading moves the element into the caller's slot.
TEST(MPMCQueueTest, MoveOnlyType) {
  // Minimal move-only type: copies deleted, moves and dtor defaulted.
  struct MoveOnly {
    MoveOnly() = default;
    ~MoveOnly() = default;
    MoveOnly(MoveOnly&&) = default;
    MoveOnly& operator=(MoveOnly&&) = default;
    MoveOnly(const MoveOnly&) = delete;
    MoveOnly& operator=(const MoveOnly&) = delete;
  };

  MPMCQueue<MoveOnly> queue(5);
  EXPECT_TRUE(queue.writeIfNotFull(MoveOnly()));

  MoveOnly received;
  EXPECT_TRUE(queue.readIfNotEmpty(received));
}