mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
This reverts commit 668082718aefce95ecc1b1c312ea6f127b2c662e. Reverted https://github.com/pytorch/pytorch/pull/85780 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but it breaks build due to a missing file <c10d/Store.hpp>
395 lines
14 KiB
C++
395 lines
14 KiB
C++
#include <c10/util/irange.h>
|
|
#include "StoreTestCommon.hpp"
|
|
|
|
#include <cstdlib>
|
|
#include <future>
|
|
#include <iostream>
|
|
#include <system_error>
|
|
#include <thread>
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <c10d/PrefixStore.hpp>
|
|
#include <c10d/TCPStore.hpp>
|
|
|
|
// Timeout (in milliseconds) installed on a store right before a lookup that
// is expected to fail, so the failing get() returns quickly.
constexpr int64_t kShortStoreTimeoutMillis{100};
// Maximum time (in milliseconds) the tests wait for watch-key callbacks.
constexpr int64_t kStoreCallbackTimeoutMillis{5000};
// Default store timeout; this value is wrapped in std::chrono::seconds.
constexpr int defaultTimeout{20};
|
|
|
|
// Creates a server-side TCPStore on the loopback address.
//
// numWorkers: worker count forwarded to the store options.
// timeout:    store timeout in seconds.
//
// The store is created with port 0 and without waiting for workers; callers
// obtain the actual port through getPort().
// NOTE(review): port 0 presumably asks the OS for a free port -- confirm.
c10::intrusive_ptr<c10d::TCPStore> _createServer(
    int numWorkers = 1,
    int timeout = defaultTimeout) {
  c10d::TCPStoreOptions serverOptions{
      /* port */ 0,
      /* isServer */ true,
      numWorkers,
      /* waitWorkers */ false,
      /* timeout */ std::chrono::seconds(timeout)};
  return c10::make_intrusive<c10d::TCPStore>("127.0.0.1", serverOptions);
}
|
|
|
|
// Different ports for different tests.
|
|
void testHelper(const std::string& prefix = "") {
|
|
constexpr auto numThreads = 16;
|
|
constexpr auto numWorkers = numThreads + 1;
|
|
|
|
auto serverTCPStore = _createServer(numWorkers);
|
|
|
|
auto serverStore =
|
|
c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore);
|
|
// server store
|
|
auto serverThread = std::thread([&serverStore, &serverTCPStore] {
|
|
// Wait for all workers to join.
|
|
serverTCPStore->waitForWorkers();
|
|
|
|
// Basic set/get on the server store
|
|
c10d::test::set(*serverStore, "key0", "value0");
|
|
c10d::test::set(*serverStore, "key1", "value1");
|
|
c10d::test::set(*serverStore, "key2", "value2");
|
|
c10d::test::check(*serverStore, "key0", "value0");
|
|
c10d::test::check(*serverStore, "key1", "value1");
|
|
c10d::test::check(*serverStore, "key2", "value2");
|
|
serverStore->add("counter", 1);
|
|
auto numKeys = serverStore->getNumKeys();
|
|
// We expect 5 keys since 3 are added above, 'counter' is added by the
|
|
// helper thread, and the init key to coordinate workers.
|
|
EXPECT_EQ(numKeys, 5);
|
|
|
|
// Check compareSet, does not check return value
|
|
c10d::test::compareSet(
|
|
*serverStore, "key0", "wrongExpectedValue", "newValue");
|
|
c10d::test::check(*serverStore, "key0", "value0");
|
|
c10d::test::compareSet(*serverStore, "key0", "value0", "newValue");
|
|
c10d::test::check(*serverStore, "key0", "newValue");
|
|
|
|
auto delSuccess = serverStore->deleteKey("key0");
|
|
// Ensure that the key was successfully deleted
|
|
EXPECT_TRUE(delSuccess);
|
|
auto delFailure = serverStore->deleteKey("badKeyName");
|
|
// The key was not in the store so the delete operation should have failed
|
|
// and returned false.
|
|
EXPECT_FALSE(delFailure);
|
|
numKeys = serverStore->getNumKeys();
|
|
EXPECT_EQ(numKeys, 4);
|
|
auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
|
|
serverStore->setTimeout(timeout);
|
|
EXPECT_THROW(serverStore->get("key0"), c10::Error);
|
|
});
|
|
|
|
// Hammer on TCPStore
|
|
std::vector<std::thread> threads;
|
|
constexpr auto numIterations = 1000;
|
|
c10d::test::Semaphore sem1, sem2;
|
|
|
|
c10d::TCPStoreOptions opts{};
|
|
opts.port = serverTCPStore->getPort();
|
|
opts.numWorkers = numWorkers;
|
|
|
|
// Each thread will have a client store to send/recv data
|
|
std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores;
|
|
std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores;
|
|
for (const auto i : c10::irange(numThreads)) {
|
|
clientTCPStores.push_back(
|
|
c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts));
|
|
clientStores.push_back(
|
|
c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i]));
|
|
}
|
|
|
|
std::string expectedCounterRes =
|
|
std::to_string(numThreads * numIterations + 1);
|
|
|
|
for (const auto i : c10::irange(numThreads)) {
|
|
threads.emplace_back(
|
|
std::thread([=, &sem1, &sem2, &clientStores, &expectedCounterRes] {
|
|
for (C10_UNUSED const auto j : c10::irange(numIterations)) {
|
|
clientStores[i]->add("counter", 1);
|
|
}
|
|
// Let each thread set and get key on its client store
|
|
std::string key = "thread_" + std::to_string(i);
|
|
for (const auto j : c10::irange(numIterations)) {
|
|
std::string val = "thread_val_" + std::to_string(j);
|
|
c10d::test::set(*clientStores[i], key, val);
|
|
c10d::test::check(*clientStores[i], key, val);
|
|
}
|
|
|
|
sem1.post();
|
|
sem2.wait();
|
|
// Check the counter results
|
|
c10d::test::check(*clientStores[i], "counter", expectedCounterRes);
|
|
// Now check other threads' written data
|
|
for (const auto j : c10::irange(numThreads)) {
|
|
if (j == i) {
|
|
continue;
|
|
}
|
|
std::string key = "thread_" + std::to_string(i);
|
|
std::string val = "thread_val_" + std::to_string(numIterations - 1);
|
|
c10d::test::check(*clientStores[i], key, val);
|
|
}
|
|
}));
|
|
}
|
|
|
|
sem1.wait(numThreads);
|
|
sem2.post(numThreads);
|
|
|
|
for (auto& thread : threads) {
|
|
thread.join();
|
|
}
|
|
|
|
serverThread.join();
|
|
|
|
// Clear the store to test that client disconnect won't shutdown the store
|
|
clientStores.clear();
|
|
clientTCPStores.clear();
|
|
|
|
// Check that the counter has the expected value
|
|
c10d::test::check(*serverStore, "counter", expectedCounterRes);
|
|
|
|
// Check that each threads' written data from the main thread
|
|
for (const auto i : c10::irange(numThreads)) {
|
|
std::string key = "thread_" + std::to_string(i);
|
|
std::string val = "thread_val_" + std::to_string(numIterations - 1);
|
|
c10d::test::check(*serverStore, key, val);
|
|
}
|
|
}
|
|
|
|
void testWatchKeyCallback(const std::string& prefix = "") {
|
|
// Callback function increments counter of the total number of callbacks that
|
|
// were run
|
|
std::promise<int> numCallbacksExecutedPromise;
|
|
std::atomic<int> numCallbacksExecuted{0};
|
|
constexpr int numThreads = 16;
|
|
constexpr int keyChangeOperation = 3;
|
|
c10d::WatchKeyCallback callback =
|
|
[=, &numCallbacksExecuted, &numCallbacksExecutedPromise](
|
|
c10::optional<std::string> /* unused */,
|
|
c10::optional<std::string> /* unused */) {
|
|
numCallbacksExecuted++;
|
|
if (numCallbacksExecuted == numThreads * keyChangeOperation * 2) {
|
|
numCallbacksExecutedPromise.set_value(numCallbacksExecuted);
|
|
}
|
|
};
|
|
|
|
const int numWorkers = numThreads + 1;
|
|
auto serverTCPStore = _createServer(numWorkers);
|
|
auto serverStore =
|
|
c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore);
|
|
|
|
c10d::TCPStoreOptions opts{};
|
|
opts.port = serverTCPStore->getPort();
|
|
opts.numWorkers = numWorkers;
|
|
|
|
// Each thread will have a client store to send/recv data
|
|
std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores;
|
|
std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores;
|
|
for (const auto i : c10::irange(numThreads)) {
|
|
clientTCPStores.push_back(
|
|
c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts));
|
|
clientStores.push_back(
|
|
c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i]));
|
|
}
|
|
|
|
// Start watching key on server and client stores
|
|
std::string internalKey = "internalKey";
|
|
std::string internalKeyCount = "internalKeyCount";
|
|
for (const auto i : c10::irange(numThreads)) {
|
|
serverStore->watchKey(internalKey + std::to_string(i), callback);
|
|
serverStore->watchKey(internalKeyCount + std::to_string(i), callback);
|
|
clientStores[i]->watchKey(internalKey + std::to_string(i), callback);
|
|
clientStores[i]->watchKey(internalKeyCount + std::to_string(i), callback);
|
|
}
|
|
|
|
std::vector<std::thread> threads;
|
|
std::atomic<int> keyChangeOperationCount{0};
|
|
for (const auto i : c10::irange(numThreads)) {
|
|
threads.emplace_back(std::thread([=,
|
|
&clientStores,
|
|
&internalKey,
|
|
&internalKeyCount,
|
|
&keyChangeOperationCount] {
|
|
// Let each thread set and get key on its client store
|
|
std::string key = internalKey + std::to_string(i);
|
|
std::string keyCounter = internalKeyCount + std::to_string(i);
|
|
std::string val = "thread_val_" + std::to_string(i);
|
|
// The set, compareSet, add methods count as key change operations
|
|
c10d::test::set(*clientStores[i], key, val);
|
|
c10d::test::compareSet(*clientStores[i], key, val, "newValue");
|
|
clientStores[i]->add(keyCounter, i);
|
|
keyChangeOperationCount += keyChangeOperation * 2;
|
|
c10d::test::check(*clientStores[i], key, "newValue");
|
|
c10d::test::check(*clientStores[i], keyCounter, std::to_string(i));
|
|
}));
|
|
}
|
|
|
|
// Ensures that internal_key has been "set" and "get"
|
|
for (auto& thread : threads) {
|
|
thread.join();
|
|
}
|
|
|
|
std::future<int> numCallbacksExecutedFuture =
|
|
numCallbacksExecutedPromise.get_future();
|
|
std::chrono::milliseconds span(kStoreCallbackTimeoutMillis);
|
|
if (numCallbacksExecutedFuture.wait_for(span) == std::future_status::timeout)
|
|
TORCH_CHECK(false, "Callback execution timed out.");
|
|
|
|
// Check number of callbacks executed equal to number of key change operations
|
|
// Wait for all callbacks to be triggered
|
|
EXPECT_EQ(keyChangeOperationCount, numCallbacksExecutedFuture.get());
|
|
}
|
|
|
|
// Runs the concurrent set/get/add scenario with the default (empty) prefix.
TEST(TCPStoreTest, testHelper) {
  testHelper();
}
|
|
|
|
// Runs the concurrent set/get/add scenario with a non-empty key prefix.
TEST(TCPStoreTest, testHelperPrefix) {
  const std::string prefix = "testPrefix";
  testHelper(prefix);
}
|
|
|
|
// Runs the watch-key callback counting scenario with the default prefix.
TEST(TCPStoreTest, testWatchKeyCallback) {
  testWatchKeyCallback();
}
|
|
|
|
// Runs the watch-key callback counting scenario with a non-empty key prefix.
TEST(TCPStoreTest, testWatchKeyCallbackWithPrefix) {
  const std::string prefix = "testPrefix";
  testWatchKeyCallback(prefix);
}
|
|
|
|
// Helper function to create a key on the store, watch it, and run the callback
|
|
void testKeyChangeHelper(
|
|
c10d::Store& store,
|
|
std::string key,
|
|
const c10::optional<std::string>& expectedOldValue,
|
|
const c10::optional<std::string>& expectedNewValue) {
|
|
std::exception_ptr eptr = nullptr;
|
|
std::promise<bool> callbackPromise;
|
|
|
|
// Test the correctness of new_value and old_value
|
|
c10d::WatchKeyCallback callback = [expectedOldValue,
|
|
expectedNewValue,
|
|
&callbackPromise,
|
|
&eptr](
|
|
c10::optional<std::string> oldValue,
|
|
c10::optional<std::string> newValue) {
|
|
try {
|
|
EXPECT_EQ(expectedOldValue.value_or("NONE"), oldValue.value_or("NONE"));
|
|
EXPECT_EQ(expectedNewValue.value_or("NONE"), newValue.value_or("NONE"));
|
|
} catch (...) {
|
|
eptr = std::current_exception();
|
|
}
|
|
callbackPromise.set_value(true);
|
|
};
|
|
store.watchKey(key, callback);
|
|
|
|
// Perform the specified update according to key
|
|
if (key == "testEmptyKeyValue" || key == "testRegularKeyValue" ||
|
|
key == "testWatchKeyCreate") {
|
|
c10d::test::set(store, key, expectedNewValue.value());
|
|
} else if (key == "testWatchKeyAdd") {
|
|
store.add(key, std::stoi(expectedNewValue.value()));
|
|
} else if (key == "testWatchKeyDelete") {
|
|
store.deleteKey(key);
|
|
}
|
|
|
|
// Test that the callback is fired and the expected values are correct
|
|
std::future<bool> callbackFuture = callbackPromise.get_future();
|
|
std::chrono::milliseconds span(kStoreCallbackTimeoutMillis);
|
|
if (callbackFuture.wait_for(span) == std::future_status::timeout)
|
|
TORCH_CHECK(false, "Callback execution timed out.");
|
|
|
|
// Any exceptions raised from asserts should be rethrown
|
|
if (eptr)
|
|
std::rethrow_exception(eptr);
|
|
}
|
|
|
|
// Overwrites a key whose current value is the empty string and expects the
// watch callback to report "" -> "2".
TEST(TCPStoreTest, testKeyEmptyUpdate) {
  const std::string key = "testEmptyKeyValue";
  const std::string initialValue = "";
  const std::string updatedValue = "2";

  auto store = _createServer();

  // Seed the key and read it back before the watch is installed.
  c10d::test::set(*store, key, initialValue);
  store->get(key);
  testKeyChangeHelper(*store, key, initialValue, updatedValue);
}
|
|
|
|
// Overwrites an existing key and expects the watch callback to report
// "1" -> "2".
TEST(TCPStoreTest, testKeyUpdate) {
  const std::string key = "testRegularKeyValue";
  const std::string initialValue = "1";
  const std::string updatedValue = "2";

  auto store = _createServer();

  // Seed the key and read it back before the watch is installed.
  c10d::test::set(*store, key, initialValue);
  store->get(key);
  testKeyChangeHelper(*store, key, initialValue, updatedValue);
}
|
|
|
|
// Sets a key that does not exist yet and expects the watch callback to report
// no old value and "2" as the new value.
TEST(TCPStoreTest, testKeyCreate) {
  const std::string key = "testWatchKeyCreate";
  const std::string createdValue = "2";

  auto store = _createServer();
  testKeyChangeHelper(*store, key, c10::nullopt, createdValue);
}
|
|
|
|
// Calls add() with 2 on a key that does not exist yet and expects the watch
// callback to report no old value and "2" as the new value.
TEST(TCPStoreTest, testKeyAdd) {
  const std::string key = "testWatchKeyAdd";
  const std::string valueAfterAdd = "2";

  auto store = _createServer();
  testKeyChangeHelper(*store, key, c10::nullopt, valueAfterAdd);
}
|
|
|
|
// Deletes an existing key and expects the watch callback to report "1" as the
// old value and no new value.
TEST(TCPStoreTest, testKeyDelete) {
  const std::string key = "testWatchKeyDelete";
  const std::string initialValue = "1";

  auto store = _createServer();

  // Seed the key and read it back before the watch is installed.
  c10d::test::set(*store, key, initialValue);
  store->get(key);
  testKeyChangeHelper(*store, key, initialValue, c10::nullopt);
}
|
|
|
|
// Destroys the server store while a client has a get() for a missing key in
// flight, and expects that get() to fail with std::system_error.
TEST(TCPStoreTest, testCleanShutdown) {
  int numWorkers = 2;
  const auto storeTimeout = std::chrono::seconds(defaultTimeout);

  // Server store, held by unique_ptr so it can be destroyed explicitly below.
  auto serverTCPStore = std::make_unique<c10d::TCPStore>(
      "127.0.0.1",
      0,
      numWorkers,
      true,
      storeTimeout,
      /* wait */ false);
  c10d::test::set(*serverTCPStore, "key", "val");

  // Client connected to whichever port the server reports.
  c10d::TCPStoreOptions clientOptions{
      /* port */ serverTCPStore->getPort(),
      /* isServer */ false,
      numWorkers,
      /* waitWorkers */ false,
      /* timeout */ storeTimeout};
  auto clientTCPStore =
      c10::make_intrusive<c10d::TCPStore>("127.0.0.1", clientOptions);
  clientTCPStore->get("key");

  // Request a key that was never set from a separate thread.
  std::thread clientThread([&clientTCPStore] {
    EXPECT_THROW(clientTCPStore->get("invalid_key"), std::system_error);
  });

  // start server shutdown during a client request
  serverTCPStore.reset();

  clientThread.join();
}
|
|
|
|
// With multiTenant enabled, two server stores built from the same options
// should share one server, and that server should survive the disposal of one
// of the two stores.
TEST(TCPStoreTest, testMultiTenantStores) {
  c10d::TCPStoreOptions serverOptions{};
  serverOptions.isServer = true;
  serverOptions.multiTenant = true;

  // Construct two server stores on the same port.
  auto firstStore =
      c10::make_intrusive<c10d::TCPStore>("localhost", serverOptions);
  auto secondStore =
      c10::make_intrusive<c10d::TCPStore>("localhost", serverOptions);

  // A value written through one store must be visible through the other.
  c10d::test::set(*firstStore, "key0", "value0");
  c10d::test::check(*secondStore, "key0", "value0");

  // Dispose the second instance and assert that the server is still alive.
  secondStore.reset();

  c10d::test::set(*firstStore, "key0", "value0");
  c10d::test::check(*firstStore, "key0", "value0");
}
|