mirror of https://github.com/pytorch/pytorch.git
Summary: The main thread establishes a dedicated stop-signal pipe for each TCPStore background thread. Before joining a background thread, the main thread would close the write end of the corresponding pipe, expecting the background thread to receive POLLHUP. Upon receiving POLLHUP, the background thread would break its loop and exit gracefully. Although we haven't found any documentation or literature backing this, we have evidence that under certain circumstances the read end of the pipe won't receive POLLHUP when the write end is closed. However, under the same circumstances, writing to the pipe guarantees that POLLIN is received on the read end.

Test Plan: Manually tested

Differential Revision: D36208897

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76973
Approved by: https://github.com/cbalioglu
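For reference, below is a minimal standalone sketch of the termination protocol this patch adopts: the main thread writes one byte (which guarantees POLLIN on the read end) and then closes the write end (which produces POLLHUP where it is delivered), while the background thread polls for both events. This is a simplified, POSIX-only illustration, not part of the patch itself:

// stop_signal_sketch.cpp -- a minimal sketch of the write-then-close stop
// signal described in the summary (POSIX only; simplified from TCPStore.cpp).
#include <poll.h>
#include <unistd.h>

#include <array>
#include <cstdio>
#include <thread>

int main() {
  std::array<int, 2> pipeFd{{-1, -1}};
  if (::pipe(pipeFd.data()) == -1) {
    return 1;
  }

  std::thread background([readFd = pipeFd[0]] {
    struct pollfd pfd = {readFd, POLLIN | POLLHUP, 0};
    // Block until the stop signal arrives. Waiting on POLLIN as well as
    // POLLHUP covers the case where closing the write end alone does not
    // surface POLLHUP on the read end.
    while (::poll(&pfd, 1, -1) >= 0) {
      if (pfd.revents & (POLLIN | POLLHUP)) {
        std::puts("stop signal received, exiting");
        return;
      }
    }
  });

  // Main thread: write one byte (guarantees POLLIN on the read end), then
  // close the write end (surfaces POLLHUP where it is delivered), then join.
  (void)::write(pipeFd[1], "\0", 1);
  ::close(pipeFd[1]);
  background.join();
  ::close(pipeFd[0]);
  return 0;
}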
1166 lines
34 KiB
C++
#include <c10/util/irange.h>
#include <c10d/TCPStore.hpp>

#include <fcntl.h>
#include <algorithm>
#include <array>
#include <system_error>
#include <thread>
#include <unordered_map>
#include <utility>

#ifdef _WIN32
#include <io.h>
#include <winsock2.h>
#else
#include <poll.h>
#include <unistd.h>
#endif

#ifdef _WIN32
#include <c10d/WinSockUtils.hpp>
#else
#include <c10d/UnixSockUtils.hpp>
#endif

#include <c10d/socket.h>

namespace c10d {
namespace detail {
namespace {

// Abstract base class to handle thread state for TCPStoreMasterDaemon and
// TCPStoreWorkerDaemon. Contains the windows/unix implementations to signal a
// shutdown sequence for the thread
class BackgroundThread {
 public:
  explicit BackgroundThread(Socket&& storeListenSocket);

  virtual ~BackgroundThread() = 0;

 protected:
  void dispose();

  Socket storeListenSocket_;
  std::thread daemonThread_{};
  std::vector<Socket> sockets_{};
#ifdef _WIN32
  const std::chrono::milliseconds checkTimeout_ = std::chrono::milliseconds{10};
  HANDLE ghStopEvent_{};
#else
  std::array<int, 2> controlPipeFd_{{-1, -1}};
#endif

 private:
  // Initialization for shutdown signal
  void initStopSignal();
  // Triggers the shutdown signal
  void stop();
  // Joins the thread
  void join();
  // Clean up the shutdown signal
  void closeStopSignal();
};

// Background thread parent class methods
BackgroundThread::BackgroundThread(Socket&& storeListenSocket)
    : storeListenSocket_{std::move(storeListenSocket)} {
  // Signal instance destruction to the daemon thread.
  initStopSignal();
}

BackgroundThread::~BackgroundThread() = default;

// WARNING:
// Since we rely on the subclass for the daemon thread clean-up, we cannot
// destruct our member variables in the destructor. The subclass must call
// dispose() in its own destructor.
void BackgroundThread::dispose() {
  // Stop the run
  stop();
  // Join the thread
  join();
  // Close unclosed sockets
  sockets_.clear();
  // Now close the rest of the control pipe
  closeStopSignal();
}

void BackgroundThread::join() {
  daemonThread_.join();
}

#ifdef _WIN32
void BackgroundThread::initStopSignal() {
  ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL);
  if (ghStopEvent_ == NULL) {
    TORCH_CHECK(
        false,
        "Failed to create the control pipe to start the "
        "BackgroundThread run");
  }
}

void BackgroundThread::closeStopSignal() {
  CloseHandle(ghStopEvent_);
}

void BackgroundThread::stop() {
  SetEvent(ghStopEvent_);
}
#else
void BackgroundThread::initStopSignal() {
  if (pipe(controlPipeFd_.data()) == -1) {
    TORCH_CHECK(
        false,
        "Failed to create the control pipe to start the "
        "BackgroundThread run");
  }
}

void BackgroundThread::closeStopSignal() {
  for (int fd : controlPipeFd_) {
    if (fd != -1) {
      ::close(fd);
    }
  }
}

void BackgroundThread::stop() {
  if (controlPipeFd_[1] != -1) {
    // Write a byte to guarantee POLLIN on the read end, then close the write
    // end of the pipe
    ::write(controlPipeFd_[1], "\0", 1);
    ::close(controlPipeFd_[1]);
    controlPipeFd_[1] = -1;
  }
}
#endif

enum class QueryType : uint8_t {
  SET,
  COMPARE_SET,
  GET,
  ADD,
  CHECK,
  WAIT,
  GETNUMKEYS,
  WATCH_KEY,
  DELETE_KEY,
};

enum class CheckResponseType : uint8_t { READY, NOT_READY };

enum class WaitResponseType : uint8_t { STOP_WAITING };

enum class WatchResponseType : uint8_t {
  KEY_UPDATED,
  KEY_CREATED,
  KEY_DELETED,
  KEY_CALLBACK_REGISTERED
};

// Separate thread that is only launched on master
class TCPStoreMasterDaemon : public BackgroundThread {
 public:
  explicit TCPStoreMasterDaemon(Socket&& storeListenSocket);

  ~TCPStoreMasterDaemon() override;

 private:
  void run();
  void queryFds(std::vector<struct pollfd>& fds);
  void query(int socket);

  // The master runs on a single thread so only
  // one handler can be executed at a time
  void setHandler(int socket);
  void compareSetHandler(int socket);
  void addHandler(int socket);
  void getHandler(int socket) const;
  void checkHandler(int socket) const;
  void getNumKeysHandler(int socket) const;
  void deleteHandler(int socket);
  void waitHandler(int socket);
  void watchHandler(int socket);

  bool checkKeys(const std::vector<std::string>& keys) const;
  // Helper function to alert waiting workers, used in setHandler, getHandler
  void wakeupWaitingClients(const std::string& key);
  // Helper function used when the key is changed
  // used in setHandler, addHandler, getHandler, deleteHandler
  void sendKeyUpdatesToClients(
      const std::string& key,
      const enum WatchResponseType& type,
      std::vector<uint8_t>& oldData,
      std::vector<uint8_t>& newData);
  std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
  // From key -> the list of sockets waiting on the key
  std::unordered_map<std::string, std::vector<int>> waitingSockets_;
  // From socket -> number of keys awaited
  std::unordered_map<int, size_t> keysAwaited_;
  // From key -> the list of sockets watching the key
  std::unordered_map<std::string, std::vector<int>> watchedSockets_;
};

// Simply start the daemon thread
TCPStoreMasterDaemon::TCPStoreMasterDaemon(Socket&& storeListenSocket)
    : BackgroundThread{std::move(storeListenSocket)} {
  daemonThread_ = std::thread{&TCPStoreMasterDaemon::run, this};
}

TCPStoreMasterDaemon::~TCPStoreMasterDaemon() {
  dispose();
}

void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
  // Skipping fds[0] and fds[1]:
  // fds[0] is the master's listening socket;
  // fds[1] is the control pipe's read fd, which is not used on Windows
  for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) {
    if (fds[fdIdx].revents == 0) {
      continue;
    }

    // Now query the socket that has the event
    try {
      query(fds[fdIdx].fd);
    } catch (...) {
      // There was an error when processing the query. Most likely an
      // exception occurred in recv/send, which would indicate that the socket
      // on the other side has been closed. If the closing was due to normal
      // exit, then the store should continue executing. Otherwise, if it was
      // a different exception, other connections will get an exception once
      // they try to use the store. We go ahead and close this connection
      // whenever we hit an exception here.

      // Remove all the tracking state of the closed fd
      for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
        for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
          if (*vecIt == fds[fdIdx].fd) {
            vecIt = it->second.erase(vecIt);
          } else {
            ++vecIt;
          }
        }
        if (it->second.size() == 0) {
          it = waitingSockets_.erase(it);
        } else {
          ++it;
        }
      }
      for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
        if (it->first == fds[fdIdx].fd) {
          it = keysAwaited_.erase(it);
        } else {
          ++it;
        }
      }
      fds.erase(fds.begin() + fdIdx);
      sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET);
      --fdIdx;
      continue;
    }
  }
}

// query communicates with the worker. The format
// of the query is as follows:
// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
// or, in the case of wait
// type of query | number of args | size of arg1 | arg1 | ...
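//
// For illustration (matching the recvString/recvVector calls in the handlers
// below), a SET of key "foo" to the two bytes {0x01, 0x02} arrives as:
//   [SET (1 byte)] [3] ['f' 'o' 'o'] [2] [0x01 0x02]
// with each size sent as a SizeType value.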
void TCPStoreMasterDaemon::query(int socket) {
  QueryType qt;
  tcputil::recvBytes<QueryType>(socket, &qt, 1);
  if (qt == QueryType::SET) {
    setHandler(socket);

  } else if (qt == QueryType::COMPARE_SET) {
    compareSetHandler(socket);

  } else if (qt == QueryType::ADD) {
    addHandler(socket);

  } else if (qt == QueryType::GET) {
    getHandler(socket);

  } else if (qt == QueryType::CHECK) {
    checkHandler(socket);

  } else if (qt == QueryType::WAIT) {
    waitHandler(socket);

  } else if (qt == QueryType::GETNUMKEYS) {
    getNumKeysHandler(socket);

  } else if (qt == QueryType::DELETE_KEY) {
    deleteHandler(socket);

  } else if (qt == QueryType::WATCH_KEY) {
    watchHandler(socket);

  } else {
    TORCH_CHECK(false, "Unexpected query type");
  }
}

void TCPStoreMasterDaemon::wakeupWaitingClients(const std::string& key) {
  auto socketsToWait = waitingSockets_.find(key);
  if (socketsToWait != waitingSockets_.end()) {
    for (int socket : socketsToWait->second) {
      if (--keysAwaited_[socket] == 0) {
        tcputil::sendValue<WaitResponseType>(
            socket, WaitResponseType::STOP_WAITING);
      }
    }
    waitingSockets_.erase(socketsToWait);
  }
}

void TCPStoreMasterDaemon::sendKeyUpdatesToClients(
    const std::string& key,
    const enum WatchResponseType& type,
    std::vector<uint8_t>& oldData,
    std::vector<uint8_t>& newData) {
  for (int socket : watchedSockets_[key]) {
    tcputil::sendValue<WatchResponseType>(socket, type);
    tcputil::sendString(socket, key, true);
    tcputil::sendVector<uint8_t>(socket, oldData);
    tcputil::sendVector<uint8_t>(socket, newData);
  }
}

void TCPStoreMasterDaemon::setHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
  std::vector<uint8_t> oldData;
  bool newKey = true;
  auto it = tcpStore_.find(key);
  if (it != tcpStore_.end()) {
    oldData = it->second;
    newKey = false;
  }
  tcpStore_[key] = newData;
  // On "set", wake up all clients that have been waiting
  wakeupWaitingClients(key);
  // Send key update to all watching clients
  newKey ? sendKeyUpdatesToClients(
               key, WatchResponseType::KEY_CREATED, oldData, newData)
         : sendKeyUpdatesToClients(
               key, WatchResponseType::KEY_UPDATED, oldData, newData);
}

void TCPStoreMasterDaemon::compareSetHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  std::vector<uint8_t> currentValue = tcputil::recvVector<uint8_t>(socket);
  std::vector<uint8_t> newValue = tcputil::recvVector<uint8_t>(socket);

  auto pos = tcpStore_.find(key);
  if (pos == tcpStore_.end()) {
    if (currentValue.empty()) {
      tcpStore_[key] = newValue;

      // Send key update to all watching clients
      sendKeyUpdatesToClients(
          key, WatchResponseType::KEY_CREATED, currentValue, newValue);
      tcputil::sendVector<uint8_t>(socket, newValue);
    } else {
      // TODO: This code path is not ideal as we are "lying" to the caller in
      // case the key does not exist. We should come up with a working solution.
      tcputil::sendVector<uint8_t>(socket, currentValue);
    }
  } else {
    if (pos->second == currentValue) {
      pos->second = std::move(newValue);

      // Send key update to all watching clients
      sendKeyUpdatesToClients(
          key, WatchResponseType::KEY_UPDATED, currentValue, pos->second);
    }
    tcputil::sendVector<uint8_t>(socket, pos->second);
  }
}

void TCPStoreMasterDaemon::addHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  int64_t addVal = tcputil::recvValue<int64_t>(socket);

  bool newKey = true;
  std::vector<uint8_t> oldData;
  auto it = tcpStore_.find(key);
  if (it != tcpStore_.end()) {
    oldData = it->second;
    auto buf = reinterpret_cast<const char*>(it->second.data());
    auto len = it->second.size();
    addVal += std::stoll(std::string(buf, len));
    newKey = false;
  }
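  // The counter is stored as its ASCII decimal representation (std::stoll
  // above, std::to_string below), so values written by add() remain readable
  // through get().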
  auto addValStr = std::to_string(addVal);
  std::vector<uint8_t> newData =
      std::vector<uint8_t>(addValStr.begin(), addValStr.end());
  tcpStore_[key] = newData;
  // Now send the new value
  tcputil::sendValue<int64_t>(socket, addVal);
  // On "add", wake up all clients that have been waiting
  wakeupWaitingClients(key);
  // Send key update to all watching clients
  newKey ? sendKeyUpdatesToClients(
               key, WatchResponseType::KEY_CREATED, oldData, newData)
         : sendKeyUpdatesToClients(
               key, WatchResponseType::KEY_UPDATED, oldData, newData);
}

void TCPStoreMasterDaemon::getHandler(int socket) const {
  std::string key = tcputil::recvString(socket);
  auto data = tcpStore_.at(key);
  tcputil::sendVector<uint8_t>(socket, data);
}

void TCPStoreMasterDaemon::getNumKeysHandler(int socket) const {
  tcputil::sendValue<int64_t>(socket, tcpStore_.size());
}

void TCPStoreMasterDaemon::deleteHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  auto it = tcpStore_.find(key);
  if (it != tcpStore_.end()) {
    std::vector<uint8_t> oldData = it->second;
    // Send key update to all watching clients
    std::vector<uint8_t> newData;
    sendKeyUpdatesToClients(
        key, WatchResponseType::KEY_DELETED, oldData, newData);
  }
  auto numDeleted = tcpStore_.erase(key);
  tcputil::sendValue<int64_t>(socket, numDeleted);
}

void TCPStoreMasterDaemon::checkHandler(int socket) const {
  SizeType nargs;
  tcputil::recvBytes<SizeType>(socket, &nargs, 1);
  std::vector<std::string> keys(nargs);
  for (const auto i : c10::irange(nargs)) {
    keys[i] = tcputil::recvString(socket);
  }
  // Now we have received all the keys
  if (checkKeys(keys)) {
    tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY);
  } else {
    tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY);
  }
}

void TCPStoreMasterDaemon::waitHandler(int socket) {
  SizeType nargs;
  tcputil::recvBytes<SizeType>(socket, &nargs, 1);
  std::vector<std::string> keys(nargs);
  for (const auto i : c10::irange(nargs)) {
    keys[i] = tcputil::recvString(socket);
  }
  if (checkKeys(keys)) {
    tcputil::sendValue<WaitResponseType>(
        socket, WaitResponseType::STOP_WAITING);
  } else {
    int numKeysToAwait = 0;
    for (auto& key : keys) {
      // Only count keys that have not already been set
      if (tcpStore_.find(key) == tcpStore_.end()) {
        waitingSockets_[key].push_back(socket);
        numKeysToAwait++;
      }
    }
    keysAwaited_[socket] = numKeysToAwait;
  }
}
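
// watchHandler is the server half of the watchKey() handshake: the
// KEY_CALLBACK_REGISTERED ack sent below is what allows
// TCPStoreWorkerDaemon::waitForCallbackRegistration() on the client to
// return.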
void TCPStoreMasterDaemon::watchHandler(int socket) {
  std::string key = tcputil::recvString(socket);

  // Record the socket to respond to when the key is updated
  watchedSockets_[key].push_back(socket);

  // Send update to the TCPStoreWorkerDaemon on the client
  tcputil::sendValue<WatchResponseType>(
      socket, WatchResponseType::KEY_CALLBACK_REGISTERED);
}

bool TCPStoreMasterDaemon::checkKeys(
    const std::vector<std::string>& keys) const {
  return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) {
    return tcpStore_.count(s) > 0;
  });
}

#ifdef _WIN32
void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);

  // receive the queries
  bool finished = false;
  while (!finished) {
    for (const auto i : c10::irange(sockets_.size())) {
      fds[i].revents = 0;
    }

    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rv = WaitForSingleObject(ghStopEvent_, 0);
      if (rv != WAIT_TIMEOUT) {
        finished = true;
        break;
      }
      continue;
    }

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      if (!(fds[0].revents & POLLIN)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      Socket socket = storeListenSocket_.accept();
      int rawSocket = socket.handle();
      sockets_.emplace_back(std::move(socket));
      tcputil::addPollfd(fds, rawSocket, POLLIN);
    }
    queryFds(fds);
  }
}
#else
void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
  // Although we haven't found any documentation or literature describing this,
  // we've seen cases where, under certain circumstances, the read end of the
  // pipe won't receive POLLHUP when the write end is closed. However, under
  // the same circumstances, writing to the pipe will guarantee POLLIN to be
  // received on the read end.
  //
  // For more reliable termination, the main thread will write a byte to the
  // pipe before closing it, and the background thread will poll for both
  // POLLIN and POLLHUP.
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);

  // receive the queries
  bool finished = false;
  while (!finished) {
    for (const auto i : c10::irange(sockets_.size())) {
      fds[i].revents = 0;
    }

    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      if (fds[0].revents ^ POLLIN) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      Socket socket = storeListenSocket_.accept();
      int rawSocket = socket.handle();
      sockets_.emplace_back(std::move(socket));
      tcputil::addPollfd(fds, rawSocket, POLLIN);
    }

    // The pipe receives an event which tells us to shut down the daemon
    if (fds[1].revents != 0) {
      // The main thread will write a byte to the pipe then close it before
      // joining the background thread
      if (fds[1].revents & ~(POLLIN | POLLHUP)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[1].revents));
      }
      finished = true;
      break;
    }
    queryFds(fds);
  }
}
#endif

// Separate thread that is launched on all instances (including master)
// Right now only handles callbacks registered from watchKey()
class TCPStoreWorkerDaemon : public BackgroundThread {
 public:
  explicit TCPStoreWorkerDaemon(Socket&& listenSocket);
  ~TCPStoreWorkerDaemon() override;
  // Set the callback to run on a key change
  void setCallback(std::string key, WatchKeyCallback cb);
  void waitForCallbackRegistration() {
    // Block until the callback has been registered successfully
    std::unique_lock<std::mutex> callbackRegistrationLock(
        callbackRegistrationMutex_);
    callbackRegisteredCV_.wait(
        callbackRegistrationLock, [&] { return callbackRegisteredData_; });

    // Reset payload for next callback
    callbackRegisteredData_ = false;
  }
  void setCallbackRegistered() {
    {
      std::unique_lock<std::mutex> callbackRegistrationLock(
          callbackRegistrationMutex_);
      callbackRegisteredData_ = true;
    }
    callbackRegisteredCV_.notify_one();
  }

 private:
  void run();
  void callbackHandler(int socket);
  // Map from each watched key to its registered callback
  std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_{};
  std::mutex keyToCallbacksMutex_{};
  std::mutex callbackRegistrationMutex_{};
  std::condition_variable callbackRegisteredCV_{};
  bool callbackRegisteredData_ = false;
};

// TCPStoreWorkerDaemon class methods
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(Socket&& listenSocket)
    : BackgroundThread{std::move(listenSocket)} {
  daemonThread_ = std::thread{&TCPStoreWorkerDaemon::run, this};
}

TCPStoreWorkerDaemon::~TCPStoreWorkerDaemon() {
  dispose();
}

void TCPStoreWorkerDaemon::setCallback(
    std::string key,
    WatchKeyCallback callback) {
  const std::lock_guard<std::mutex> lock(keyToCallbacksMutex_);
  keyToCallbacks_[key] = callback;
}

// Runs all the callbacks that the worker has registered
void TCPStoreWorkerDaemon::callbackHandler(int socket) {
  auto watchResponse = tcputil::recvValue<WatchResponseType>(socket);
  if (watchResponse == WatchResponseType::KEY_CALLBACK_REGISTERED) {
    // Notify the waiting "watchKey" operation to return
    setCallbackRegistered();
    return;
  }
  std::string key = tcputil::recvString(socket);
  std::vector<uint8_t> currentValueVec = tcputil::recvVector<uint8_t>(socket);
  std::vector<uint8_t> newValueVec = tcputil::recvVector<uint8_t>(socket);
  c10::optional<std::string> currentValue;
  if (watchResponse == WatchResponseType::KEY_CREATED) {
    assert(currentValueVec.empty());
    currentValue = c10::nullopt;
  } else {
    currentValue = std::string(currentValueVec.begin(), currentValueVec.end());
  }
  c10::optional<std::string> newValue;
  if (watchResponse == WatchResponseType::KEY_DELETED) {
    assert(newValueVec.empty());
    newValue = c10::nullopt;
  } else {
    newValue = std::string(newValueVec.begin(), newValueVec.end());
  }
  const std::lock_guard<std::mutex> lock(keyToCallbacksMutex_);
  keyToCallbacks_.at(key)(currentValue, newValue);
}

#ifdef _WIN32
void TCPStoreWorkerDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);

  while (true) {
    // Check control and exit early if triggered
    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rvPoll = WaitForSingleObject(ghStopEvent_, 0);
      if (rvPoll != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // If the connection is closed gracefully by the master, the peeked recv
    // will return 0
    char data;
    int ret = recv(fds[0].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      auto rvData = WaitForSingleObject(ghStopEvent_, 0);
      if (rvData != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[0].fd);
  }
}
#else
void TCPStoreWorkerDaemon::run() {
  std::vector<struct pollfd> fds;
  // Although we haven't found any documentation or literature describing this,
  // we've seen cases where, under certain circumstances, the read end of the
  // pipe won't receive POLLHUP when the write end is closed. However, under
  // the same circumstances, writing to the pipe will guarantee POLLIN to be
  // received on the read end.
  //
  // For more reliable termination, the main thread will write a byte to the
  // pipe before closing it, and the background thread will poll for both
  // POLLIN and POLLHUP.
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);

  while (true) {
    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // Check control and exit early if triggered
    // The pipe receives an event which tells us to shut down the listener
    // thread
    if (fds[0].revents != 0) {
      // The main thread will write a byte to the pipe then close it before
      // joining the background thread
      if (fds[0].revents & ~(POLLIN | POLLHUP)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[0].revents));
      }
      break;
    }

    // If the connection is closed gracefully by the master, the peeked recv
    // will return 0
    char data;
    int ret = recv(fds[1].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[1].fd);
  }
}
#endif

} // namespace

// Manages the lifecycle of a server daemon.
class TCPServer {
 public:
  static std::shared_ptr<TCPServer> start(const TCPStoreOptions& opts);

  std::uint16_t port() const noexcept {
    return port_;
  }

  explicit TCPServer(
      std::uint16_t port,
      std::unique_ptr<TCPStoreMasterDaemon>&& daemon)
      : port_{port}, daemon_{std::move(daemon)} {}

 private:
  std::uint16_t port_;
  std::unique_ptr<TCPStoreMasterDaemon> daemon_;

  // We store weak references to all TCPServers for which the caller requested
  // multi-tenancy.
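  // Storing weak_ptr rather than shared_ptr means a cached server is
  // destroyed as soon as the last TCPStore using it goes away; start()
  // lazily evicts such expired entries before reusing a port.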
  static std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
      cachedServers_;

  static std::mutex cache_mutex_;
};

std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
    TCPServer::cachedServers_{};

std::mutex TCPServer::cache_mutex_{};

std::shared_ptr<TCPServer> TCPServer::start(const TCPStoreOptions& opts) {
  auto startCore = [&opts]() {
    Socket socket = Socket::listen(opts.port);

    std::uint16_t port = socket.port();

    auto daemon = std::make_unique<TCPStoreMasterDaemon>(std::move(socket));

    return std::make_shared<TCPServer>(port, std::move(daemon));
  };

  std::shared_ptr<TCPServer> server{};

  if (opts.multiTenant) {
    std::lock_guard<std::mutex> guard{cache_mutex_};

    // If the caller is okay with a multi-tenant store, first check if we
    // already have a TCPServer running on the specified port.
    if (opts.port > 0) {
      auto pos = cachedServers_.find(opts.port);
      if (pos != cachedServers_.end()) {
        server = pos->second.lock();
        if (server != nullptr) {
          return server;
        }

        // Looks like the TCPStore has been disposed; make sure that we
        // release the control block.
        cachedServers_.erase(pos);
      }
    }

    server = startCore();

    cachedServers_.emplace(server->port(), server);
  } else {
    server = startCore();
  }

  return server;
}

class TCPClient {
 public:
  static std::unique_ptr<TCPClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts);

  void sendCommand(QueryType type) {
    tcputil::sendValue<QueryType>(socket_.handle(), type);
  }

  void sendCommandForKey(QueryType type, const std::string& key);

  void sendBytes(const std::vector<std::uint8_t>& value) {
    tcputil::sendVector<std::uint8_t>(socket_.handle(), value);
  }

  void sendStrings(c10::ArrayRef<std::string> value);

  template <typename T>
  void sendValue(const T& value) {
    tcputil::sendValue<T>(socket_.handle(), value);
  }

  std::vector<std::uint8_t> receiveBits() {
    return tcputil::recvVector<std::uint8_t>(socket_.handle());
  }

  template <typename T>
  T receiveValue() {
    return tcputil::recvValue<T>(socket_.handle());
  }

  void setTimeout(std::chrono::milliseconds value);

  explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {}

 private:
  Socket socket_;
};

std::unique_ptr<TCPClient> TCPClient::connect(
    const SocketAddress& addr,
    const TCPStoreOptions& opts) {
  auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
  Socket socket = Socket::connect(
      addr.host, addr.port, SocketOptions{}.connect_timeout(timeout));

  return std::make_unique<TCPClient>(std::move(socket));
}

void TCPClient::sendCommandForKey(QueryType type, const std::string& key) {
  tcputil::sendValue<QueryType>(socket_.handle(), type);

  bool withValue = type == QueryType::SET || type == QueryType::COMPARE_SET ||
      type == QueryType::ADD;

  tcputil::sendString(socket_.handle(), key, withValue);
}

void TCPClient::sendStrings(c10::ArrayRef<std::string> value) {
  std::size_t size = value.size();

  tcputil::sendBytes<std::size_t>(socket_.handle(), &size, 1, size > 0);

  if (value.empty()) {
    return;
  }

  for (auto pos = value.begin(), last = value.end() - 1; pos <= last; ++pos) {
    tcputil::sendString(socket_.handle(), *pos, pos != last);
  }
}

void TCPClient::setTimeout(std::chrono::milliseconds value) {
  if (value == std::chrono::milliseconds::zero()) {
    return;
  }

#ifdef _WIN32
  struct timeval timeoutTV = {
      static_cast<long>(value.count() / 1000),
      static_cast<long>((value.count() % 1000) * 1000)};
#else
  struct timeval timeoutTV = {
      .tv_sec = value.count() / 1000,
      .tv_usec = static_cast<suseconds_t>((value.count() % 1000) * 1000),
  };
#endif
  SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
      socket_.handle(),
      SOL_SOCKET,
      SO_RCVTIMEO,
      reinterpret_cast<char*>(&timeoutTV),
      sizeof(timeoutTV)));
}

class TCPCallbackClient {
 public:
  static std::unique_ptr<TCPCallbackClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts);

  void setCallback(const std::string& key, WatchKeyCallback callback);

  explicit TCPCallbackClient(
      int rawSocket,
      std::unique_ptr<TCPStoreWorkerDaemon>&& daemon)
      : rawSocket_{rawSocket}, daemon_{std::move(daemon)} {}

 private:
  int rawSocket_;
  std::unique_ptr<TCPStoreWorkerDaemon> daemon_;
  std::mutex mutex_;
};

std::unique_ptr<TCPCallbackClient> TCPCallbackClient::connect(
    const SocketAddress& addr,
    const TCPStoreOptions& opts) {
  auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
  Socket socket = Socket::connect(
      addr.host, addr.port, SocketOptions{}.connect_timeout(timeout));

  int rawSocket = socket.handle();

  auto daemon = std::make_unique<TCPStoreWorkerDaemon>(std::move(socket));

  return std::make_unique<TCPCallbackClient>(rawSocket, std::move(daemon));
}

void TCPCallbackClient::setCallback(
    const std::string& key,
    WatchKeyCallback callback) {
  std::lock_guard<std::mutex> guard{mutex_};

  daemon_->setCallback(key, callback);

  tcputil::sendValue<QueryType>(rawSocket_, QueryType::WATCH_KEY);

  tcputil::sendString(rawSocket_, key);

  daemon_->waitForCallbackRegistration();
}

} // namespace detail

using detail::Socket;

// TCPStore class methods
TCPStore::TCPStore(
    const std::string& masterAddr,
    std::uint16_t masterPort,
    c10::optional<int> numWorkers,
    bool isServer,
    const std::chrono::milliseconds& timeout,
    bool waitWorkers)
    : TCPStore{
          masterAddr,
          TCPStoreOptions{
              masterPort,
              isServer,
              numWorkers ? c10::optional<std::size_t>(*numWorkers)
                         : c10::nullopt,
              waitWorkers,
              timeout}} {}

TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
    : Store{opts.timeout},
      addr_{std::move(host)},
      numWorkers_{opts.numWorkers} {
  Socket::initialize();

  if (opts.isServer) {
    server_ = detail::TCPServer::start(opts);

    addr_.port = server_->port();
  } else {
    addr_.port = opts.port;
  }

  client_ = detail::TCPClient::connect(addr_, opts);

  if (opts.waitWorkers) {
    waitForWorkers();
  }

  callbackClient_ = detail::TCPCallbackClient::connect(addr_, opts);
}

TCPStore::~TCPStore() = default;

void TCPStore::waitForWorkers() {
  if (numWorkers_ == c10::nullopt) {
    return;
  }

  incrementValueBy(initKey_, 1);

  // Let the server block until all workers have completed; this ensures that
  // the server daemon thread is always running until the very end
  if (server_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      // TODO: Any chance to make this cleaner?
      std::vector<uint8_t> value = doGet(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= *numWorkers_) {
        break;
      }
      const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
          std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        break;
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}

void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  client_->sendCommandForKey(detail::QueryType::SET, keyPrefix_ + key);
  client_->sendBytes(data);
}

std::vector<uint8_t> TCPStore::compareSet(
    const std::string& key,
    const std::vector<uint8_t>& expectedValue,
    const std::vector<uint8_t>& desiredValue) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  client_->sendCommandForKey(detail::QueryType::COMPARE_SET, keyPrefix_ + key);
  client_->sendBytes(expectedValue);
  client_->sendBytes(desiredValue);

  return client_->receiveBits();
}

std::vector<uint8_t> TCPStore::get(const std::string& key) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  return doGet(keyPrefix_ + key);
}

std::vector<uint8_t> TCPStore::doGet(const std::string& key) {
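  // Block until the key exists (or the timeout expires) before issuing the
  // GET, since the server's getHandler() indexes tcpStore_ with at() and a
  // missing key would fail the request.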
  doWait(key, timeout_);
  client_->sendCommandForKey(detail::QueryType::GET, key);
  return client_->receiveBits();
}

int64_t TCPStore::add(const std::string& key, int64_t value) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  return incrementValueBy(keyPrefix_ + key, value);
}

bool TCPStore::deleteKey(const std::string& key) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  client_->sendCommandForKey(detail::QueryType::DELETE_KEY, keyPrefix_ + key);
  auto numDeleted = client_->receiveValue<std::int64_t>();
  return numDeleted == 1;
}

void TCPStore::watchKey(const std::string& key, WatchKeyCallback callback) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  callbackClient_->setCallback(keyPrefix_ + key, callback);
}

int64_t TCPStore::incrementValueBy(const std::string& key, int64_t delta) {
  client_->sendCommandForKey(detail::QueryType::ADD, key);
  client_->sendValue<std::int64_t>(delta);
  return client_->receiveValue<std::int64_t>();
}

int64_t TCPStore::getNumKeys() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  client_->sendCommand(detail::QueryType::GETNUMKEYS);
  return client_->receiveValue<std::int64_t>();
}

bool TCPStore::check(const std::vector<std::string>& keys) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  std::vector<std::string> prefixedKeys{};
  prefixedKeys.reserve(keys.size());
  for (const std::string& key : keys) {
    prefixedKeys.emplace_back(keyPrefix_ + key);
  }

  client_->sendCommand(detail::QueryType::CHECK);
  client_->sendStrings(prefixedKeys);

  auto response = client_->receiveValue<detail::CheckResponseType>();
  if (response == detail::CheckResponseType::READY) {
    return true;
  }
  if (response == detail::CheckResponseType::NOT_READY) {
    return false;
  }
  TORCH_CHECK(false, "ready or not_ready response expected");
}

void TCPStore::wait(const std::vector<std::string>& keys) {
  wait(keys, timeout_);
}

void TCPStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  std::vector<std::string> prefixedKeys{};
  prefixedKeys.reserve(keys.size());
  for (const std::string& key : keys) {
    prefixedKeys.emplace_back(keyPrefix_ + key);
  }

  doWait(prefixedKeys, timeout);
}

void TCPStore::doWait(
    c10::ArrayRef<std::string> keys,
    std::chrono::milliseconds timeout) {
  // TODO: Should we revert to the original timeout at the end of the call?
  client_->setTimeout(timeout);

  client_->sendCommand(detail::QueryType::WAIT);
  client_->sendStrings(keys);

  auto response = client_->receiveValue<detail::WaitResponseType>();
  if (response != detail::WaitResponseType::STOP_WAITING) {
    TORCH_CHECK(false, "Stop_waiting response is expected");
  }
}

} // namespace c10d
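
For context, a minimal usage sketch of the API implemented in this file. The TCPStoreOptions field names (port, isServer, numWorkers, waitWorkers) follow the usages visible above; exact defaults and field types are assumptions, so treat this as an illustrative sketch rather than reference documentation:

// usage_sketch.cpp (hypothetical example, not part of TCPStore.cpp)
#include <c10d/TCPStore.hpp>

#include <cstdint>
#include <vector>

int main() {
  c10d::TCPStoreOptions opts{};
  opts.port = 29500;
  opts.isServer = true; // this process hosts the TCPStoreMasterDaemon
  opts.numWorkers = std::size_t{1}; // just this process
  opts.waitWorkers = true;

  // Workers would construct the same object with isServer = false and the
  // server's hostname instead of "localhost".
  c10d::TCPStore store{"localhost", opts};

  store.set("status", std::vector<std::uint8_t>{'o', 'k'});
  store.wait({"status"}); // returns immediately: the key already exists
  std::vector<std::uint8_t> value = store.get("status");
  std::int64_t counter = store.add("epoch", 1); // stored as ASCII decimal
  (void)value;
  (void)counter;
  return 0;
}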