mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[nativert] Move MPMCQueue to torch/nativert. (#152837)
Summary: Torch Native Runtime RFC: https://github.com/zhxchen17/rfcs/blob/master/RFC-0043-torch-native-runtime.md To land the runtime into PyTorch core, we will gradually land logical parts of the code into the Github issue and get each piece properly reviewed. This diff adds a small library implementing a multi producer multi consumer queue which will be used to synchronize tasks for Torch Native Runtime. Differential Revision: D74184245 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152837 Approved by: https://github.com/albanD, https://github.com/dolpm, https://github.com/swolchok
This commit is contained in:
committed by
PyTorch MergeBot
parent
d2ee606e9b
commit
5bb154e6fd
@ -3,6 +3,7 @@ set(NATIVERT_TEST_ROOT ${TORCH_ROOT}/test/cpp/nativert)
|
||||
# Build the cpp gtest binary containing the cpp-only tests.
|
||||
# Sources for the nativert gtest binary: the cpp-only tests plus the
# torch/nativert sources they exercise.
set(NATIVERT_TEST_SRCS
  ${NATIVERT_TEST_ROOT}/test_tensor_meta.cpp
  ${NATIVERT_TEST_ROOT}/test_mpmc_queue.cpp
  ${TORCH_ROOT}/torch/nativert/graph/TensorMeta.cpp
)
|
||||
|
||||
|
121
test/cpp/nativert/test_mpmc_queue.cpp
Normal file
121
test/cpp/nativert/test_mpmc_queue.cpp
Normal file
@ -0,0 +1,121 @@
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/nativert/detail/MPMCQueue.h>
|
||||
|
||||
using torch::nativert::detail::MPMCQueue;
|
||||
|
||||
TEST(MPMCQueueTest, EmptyQueue) {
  // A freshly constructed queue contains nothing, so a non-blocking read
  // must report failure.
  MPMCQueue<int> queue(5);
  int value = 0;
  EXPECT_FALSE(queue.readIfNotEmpty(value));
}
|
||||
|
||||
TEST(MPMCQueueTest, SingleElement) {
  // One write followed by one read round-trips the value unchanged.
  MPMCQueue<int> queue(5);
  EXPECT_TRUE(queue.writeIfNotFull(10));
  int result = 0;
  EXPECT_TRUE(queue.readIfNotEmpty(result));
  EXPECT_EQ(result, 10);
}
|
||||
|
||||
TEST(MPMCQueueTest, MultipleElements) {
  // FIFO order: values must come out in exactly the order they went in.
  MPMCQueue<int> queue(5);
  for (int value = 0; value < 5; ++value) {
    EXPECT_TRUE(queue.writeIfNotFull(value));
  }
  for (int expected = 0; expected < 5; ++expected) {
    int actual = 0;
    EXPECT_TRUE(queue.readIfNotEmpty(actual));
    EXPECT_EQ(actual, expected);
  }
}
|
||||
|
||||
TEST(MPMCQueueTest, FullQueue) {
  // Once the queue holds `capacity` elements, further writes are rejected.
  MPMCQueue<int> queue(5);
  for (int value = 0; value < 5; ++value) {
    EXPECT_TRUE(queue.writeIfNotFull(value));
  }
  EXPECT_FALSE(queue.writeIfNotFull(10));
}
|
||||
|
||||
TEST(MPMCQueueTest, ConcurrentAccess) {
  // One producer and one consumer share the queue; capacity 10 exceeds the
  // five writes, so every writeIfNotFull is expected to land.
  MPMCQueue<int> queue(10);
  std::thread producer([&queue]() {
    for (int value = 0; value < 5; ++value) {
      queue.writeIfNotFull(value);
    }
  });
  std::thread consumer([&queue]() {
    for (int count = 0; count < 5; ++count) {
      int received = 0;
      while (!queue.readIfNotEmpty(received)) {
        // Busy-wait until the producer publishes an element.
      }
      EXPECT_LT(received, 5);
    }
  });
  producer.join();
  consumer.join();
}
|
||||
|
||||
TEST(MPMCQueueTest, MPMCConcurrentAccess) {
  // Stress test: several writers and several readers hammer one queue and
  // we verify that every written element is read exactly once.
  const size_t queueCapacity = 100000;
  const size_t numWriters = 5;
  const size_t numReaders = 5;
  const size_t numElementsPerWriter = 10000;
  MPMCQueue<int> queue(queueCapacity);
  // Writer threads
  std::vector<std::thread> writers;
  writers.reserve(numWriters);
  for (size_t i = 0; i < numWriters; ++i) {
    // BUGFIX: capture `i` by value. Capturing it by reference (`[&]`) races
    // with the loop increment, so a writer could observe a stale or
    // already-advanced index and produce duplicate/missing values.
    writers.emplace_back([&, i]() {
      for (size_t j = 0; j < numElementsPerWriter; ++j) {
        // Each writer produces a disjoint range of values.
        size_t value = i * numElementsPerWriter + j;
        while (!queue.writeIfNotFull(static_cast<int>(value))) {
          // Retry until the queue has space
        }
      }
    });
  }
  // Reader threads
  std::vector<std::thread> readers;
  std::atomic<size_t> totalReadCount{0};
  readers.reserve(numReaders);
  for (size_t i = 0; i < numReaders; ++i) {
    readers.emplace_back([&]() {
      int value = 0;
      // Spin until the collective read count reaches the total produced.
      while (totalReadCount < numWriters * numElementsPerWriter) {
        if (queue.readIfNotEmpty(value)) {
          ++totalReadCount;
        }
      }
    });
  }
  // Join all threads
  for (auto& writer : writers) {
    writer.join();
  }
  for (auto& reader : readers) {
    reader.join();
  }
  // Verify that all elements were read
  EXPECT_EQ(totalReadCount, numWriters * numElementsPerWriter);
}
|
||||
|
||||
TEST(MPMCQueueTest, MoveOnlyType) {
  // The queue must work with types that are movable but not copyable.
  struct MoveOnly {
    MoveOnly() = default;
    ~MoveOnly() = default;
    MoveOnly(const MoveOnly&) = delete;
    MoveOnly& operator=(const MoveOnly&) = delete;
    MoveOnly(MoveOnly&&) = default;
    MoveOnly& operator=(MoveOnly&&) = default;
  };
  MPMCQueue<MoveOnly> queue(5);
  EXPECT_TRUE(queue.writeIfNotFull(MoveOnly{}));
  MoveOnly received;
  EXPECT_TRUE(queue.readIfNotEmpty(received));
}
|
63
torch/nativert/detail/MPMCQueue.h
Normal file
63
torch/nativert/detail/MPMCQueue.h
Normal file
@ -0,0 +1,63 @@
|
||||
/*
|
||||
* A simple thread-safe multi-producer, multi-consumer queue.
|
||||
*
|
||||
 * This is a wrapper around std::deque that provides non-blocking
 * queue operations like readIfNotEmpty and writeIfNotFull using
 * std mutexes; the underlying queue may only be accessed
 * within synchronized sections.
||||
*
|
||||
* For now the goal is to provide a simple implementation that
|
||||
* works in all cases and produces no surprises to users.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
#include <type_traits>
|
||||
|
||||
namespace torch::nativert::detail {
|
||||
|
||||
// TODO(zhxchen17) Add wrapper for concurrentqueue.
|
||||
template <typename T>
class MPMCQueue {
  // References cannot be stored in the underlying deque.
  static_assert(!std::is_reference_v<T>);

 public:
  /// Construct a queue that holds at most `capacity` elements.
  explicit MPMCQueue(size_t capacity) : capacity_(capacity) {}

  /**
   * Read from the queue if it is not empty.
   * @param out The value to read into.
   * @return true if the read succeeded, false if the queue is empty.
   */
  bool readIfNotEmpty(T& out) {
    std::lock_guard<std::mutex> guard(mutex_);
    if (!storage_.empty()) {
      // Move the oldest element out (FIFO order) and drop it.
      out = std::move(storage_.front());
      storage_.pop_front();
      return true;
    }
    return false;
  }

  /**
   * Write to the queue if it is not full.
   * @param in The value to write. For now we only support moveable types.
   * @return true if the write succeeded, false if the queue is full.
   */
  bool writeIfNotFull(T in) {
    std::lock_guard<std::mutex> guard(mutex_);
    if (storage_.size() < capacity_) {
      storage_.push_back(std::move(in));
      return true;
    }
    return false;
  }

 private:
  std::mutex mutex_; // guards every access to storage_
  std::deque<T> storage_;
  size_t capacity_;
};
|
||||
} // namespace torch::nativert::detail
|
Reference in New Issue
Block a user