[nativert] Move MPMCQueue to torch/nativert. (#152837)

Summary:
Torch Native Runtime RFC: https://github.com/zhxchen17/rfcs/blob/master/RFC-0043-torch-native-runtime.md

To land the runtime into PyTorch core, we will gradually land logical parts of the code into the Github issue and get each piece properly reviewed.

This diff adds a small library implementing a multi-producer, multi-consumer queue which will be used to synchronize tasks for the Torch Native Runtime.

Differential Revision: D74184245

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152837
Approved by: https://github.com/albanD, https://github.com/dolpm, https://github.com/swolchok
This commit is contained in:
Zhengxu Chen
2025-05-07 21:17:42 +00:00
committed by PyTorch MergeBot
parent d2ee606e9b
commit 5bb154e6fd
3 changed files with 185 additions and 0 deletions

View File

@ -3,6 +3,7 @@ set(NATIVERT_TEST_ROOT ${TORCH_ROOT}/test/cpp/nativert)
# Build the cpp gtest binary containing the cpp-only tests.
set(NATIVERT_TEST_SRCS
${NATIVERT_TEST_ROOT}/test_tensor_meta.cpp
${NATIVERT_TEST_ROOT}/test_mpmc_queue.cpp
${TORCH_ROOT}/torch/nativert/graph/TensorMeta.cpp
)

View File

@ -0,0 +1,121 @@
#include <atomic>
#include <thread>
#include <gtest/gtest.h>
#include <torch/nativert/detail/MPMCQueue.h>
using torch::nativert::detail::MPMCQueue;
// Reading from a freshly constructed queue must fail: nothing was written yet.
TEST(MPMCQueueTest, EmptyQueue) {
  MPMCQueue<int> queue(5);
  int value = 0;
  EXPECT_FALSE(queue.readIfNotEmpty(value));
}
// A single value written to the queue comes back out unchanged.
TEST(MPMCQueueTest, SingleElement) {
  MPMCQueue<int> queue(5);
  EXPECT_TRUE(queue.writeIfNotFull(10));
  int result = 0;
  EXPECT_TRUE(queue.readIfNotEmpty(result));
  EXPECT_EQ(result, 10);
}
// Elements must come out in the same (FIFO) order they went in.
TEST(MPMCQueueTest, MultipleElements) {
  constexpr int kCount = 5;
  MPMCQueue<int> queue(kCount);
  for (int value = 0; value < kCount; ++value) {
    EXPECT_TRUE(queue.writeIfNotFull(value));
  }
  for (int expected = 0; expected < kCount; ++expected) {
    int actual = 0;
    EXPECT_TRUE(queue.readIfNotEmpty(actual));
    EXPECT_EQ(actual, expected);
  }
}
// Once the queue holds `capacity` elements, further writes are rejected.
TEST(MPMCQueueTest, FullQueue) {
  constexpr int kCapacity = 5;
  MPMCQueue<int> queue(kCapacity);
  // Fill the queue to capacity.
  for (int value = 0; value < kCapacity; ++value) {
    EXPECT_TRUE(queue.writeIfNotFull(value));
  }
  // One more write must fail without modifying the queue.
  EXPECT_FALSE(queue.writeIfNotFull(10));
}
// One producer and one consumer run concurrently; the consumer spins until
// each element becomes available, so all five values are observed.
TEST(MPMCQueueTest, ConcurrentAccess) {
  MPMCQueue<int> queue(10);
  std::thread producer([&queue] {
    for (int value = 0; value < 5; ++value) {
      queue.writeIfNotFull(value);
    }
  });
  std::thread consumer([&queue] {
    for (int iteration = 0; iteration < 5; ++iteration) {
      int value = 0;
      while (!queue.readIfNotEmpty(value)) {
        // Spin until the producer publishes the next element.
      }
      EXPECT_LT(value, 5);
    }
  });
  producer.join();
  consumer.join();
}
// Many writers and many readers hammer the queue concurrently; in total,
// every written element must be consumed exactly once.
TEST(MPMCQueueTest, MPMCConcurrentAccess) {
  const size_t queueCapacity = 100000;
  const size_t numWriters = 5;
  const size_t numReaders = 5;
  const size_t numElementsPerWriter = 10000;
  MPMCQueue<int> queue(queueCapacity);
  // Writer threads.
  // BUG FIX: `i` must be captured by value. The original `[&]` capture made
  // every writer read the loop variable through a reference while the main
  // thread kept incrementing it (a data race producing wrong writer ids), and
  // the reference dangled once the loop exited before join().
  std::vector<std::thread> writers;
  writers.reserve(numWriters);
  for (size_t i = 0; i < numWriters; ++i) {
    writers.emplace_back([&, i]() {
      for (size_t j = 0; j < numElementsPerWriter; ++j) {
        size_t value = i * numElementsPerWriter + j;
        while (!queue.writeIfNotFull(static_cast<int>(value))) {
          // Retry until the queue has space
        }
      }
    });
  }
  // Reader threads: spin until the shared counter shows that every written
  // element has been consumed, then exit.
  std::vector<std::thread> readers;
  std::atomic<size_t> totalReadCount{0};
  readers.reserve(numReaders);
  for (size_t i = 0; i < numReaders; ++i) {
    readers.emplace_back([&]() {
      int value = 0;
      while (totalReadCount < numWriters * numElementsPerWriter) {
        if (queue.readIfNotEmpty(value)) {
          ++totalReadCount;
        }
      }
    });
  }
  // Join all threads
  for (auto& writer : writers) {
    writer.join();
  }
  for (auto& reader : readers) {
    reader.join();
  }
  // Verify that all elements were read
  EXPECT_EQ(totalReadCount, numWriters * numElementsPerWriter);
}
// The queue must support move-only element types: no copies are required on
// either the write or the read path.
TEST(MPMCQueueTest, MoveOnlyType) {
  struct MoveOnly {
    MoveOnly() = default;
    MoveOnly(const MoveOnly&) = delete;
    MoveOnly& operator=(const MoveOnly&) = delete;
    MoveOnly(MoveOnly&&) = default;
    MoveOnly& operator=(MoveOnly&&) = default;
    ~MoveOnly() = default;
  };
  MPMCQueue<MoveOnly> queue(5);
  EXPECT_TRUE(queue.writeIfNotFull(MoveOnly{}));
  MoveOnly received;
  EXPECT_TRUE(queue.readIfNotEmpty(received));
}

View File

@ -0,0 +1,63 @@
/*
* A simple thread-safe multi-producer, multi-consumer queue.
*
* This is a wrapper around std::deque that provides non-blocking
* queue operations like readIfNotEmpty and writeIfNotFull using
* std mutexes and the underlying queue can only be accessed
* with synchronized sections.
*
* For now the goal is to provide a simple implementation that
* works in all cases and produces no surprises to users.
*/
#pragma once
#include <deque>
#include <mutex>
#include <type_traits>
namespace torch::nativert::detail {
// TODO(zhxchen17) Add wrapper for concurrentqueue.
template <typename T>
class MPMCQueue {
  // Reference element types make no sense for a value queue.
  static_assert(!std::is_reference_v<T>);

 public:
  explicit MPMCQueue(size_t capacity) : capacity_(capacity) {}
  /**
   * Pop the front element into `out` if one is available.
   * @param out Receives the popped value (moved out of the queue).
   * @return true if an element was popped, false if the queue was empty.
   */
  bool readIfNotEmpty(T& out) {
    std::scoped_lock guard(mutex_);
    if (storage_.empty()) {
      return false;
    }
    out = std::move(storage_.front());
    storage_.pop_front();
    return true;
  }
  /**
   * Push a value onto the back of the queue if there is room.
   * @param in The value to write. For now we only support moveable types.
   * @return true if the value was enqueued, false if the queue was full.
   */
  bool writeIfNotFull(T in) {
    std::scoped_lock guard(mutex_);
    if (storage_.size() == capacity_) {
      return false;
    }
    storage_.push_back(std::move(in));
    return true;
  }

 private:
  std::mutex mutex_;       // guards storage_; all access is under this lock
  std::deque<T> storage_;  // underlying FIFO buffer
  size_t capacity_;        // maximum number of stored elements
};
} // namespace torch::nativert::detail