Files
pytorch/torch/csrc/distributed/c10d/TCPStore.cpp
Yifu Wang 27ea79b8a5 Use a more reliable signaling mechanism to stop TCPStore background threads (#76973)
Summary:
The main thread establishes a dedicated stop signal pipe for each TCPStore
background thread. Before joining a background thread, the main thread would
close the write end of the corresponding pipe, expecting the background
thread to receive POLLHUP. Upon receiving POLLHUP, the background thread would
break the loop and exit gracefully.

Although we haven't found any documentation or literature backing this, we have
evidence that under certain circumstances, the read end of the pipe won't
receive POLLHUP when the write end is closed. However, under the same
circumstances, writing to the pipe will guarantee POLLIN to be received on the
read end.

Test Plan: Manually tested

Differential Revision: D36208897

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76973
Approved by: https://github.com/cbalioglu
2022-05-10 23:24:08 +00:00

1166 lines
34 KiB
C++

#include <c10/util/irange.h>
#include <c10d/TCPStore.hpp>
#include <fcntl.h>
#include <algorithm>
#include <array>
#include <system_error>
#include <thread>
#include <unordered_map>
#include <utility>
#ifdef _WIN32
#include <io.h>
#include <winsock2.h>
#else
#include <poll.h>
#include <unistd.h>
#endif
#ifdef _WIN32
#include <c10d/WinSockUtils.hpp>
#else
#include <c10d/UnixSockUtils.hpp>
#endif
#include <c10d/socket.h>
namespace c10d {
namespace detail {
namespace {
// Common base of TCPStoreMasterDaemon and TCPStoreWorkerDaemon. It holds the
// daemon thread, the sockets the thread serves, and the platform-specific
// stop signal (an event handle on Windows, a pipe elsewhere) that is used to
// ask the thread to leave its loop.
class BackgroundThread {
 public:
  explicit BackgroundThread(Socket&& storeListenSocket);

  // Pure virtual so the base cannot be instantiated on its own; a definition
  // is still supplied out of line.
  virtual ~BackgroundThread() = 0;

 protected:
  // Signals the thread to stop, joins it, then releases the sockets and the
  // stop signal. Derived destructors are expected to call this.
  void dispose();

  Socket storeListenSocket_;
  std::thread daemonThread_;
  std::vector<Socket> sockets_;
#ifdef _WIN32
  // Poll timeout between two checks of the stop event.
  const std::chrono::milliseconds checkTimeout_{10};
  HANDLE ghStopEvent_{};
#else
  // Stop pipe: [0] is the read end polled by the daemon thread, [1] is the
  // write end used by stop(). -1 marks a descriptor that is not open.
  std::array<int, 2> controlPipeFd_{{-1, -1}};
#endif

 private:
  // Creates the stop signal.
  void initStopSignal();
  // Raises the stop signal.
  void stop();
  // Waits for the daemon thread to finish.
  void join();
  // Releases the stop signal's OS resources.
  void closeStopSignal();
};
// BackgroundThread member definitions.
//
// Takes ownership of the given socket and creates the stop signal. The
// daemon thread itself is started by the derived class constructors.
BackgroundThread::BackgroundThread(Socket&& storeListenSocket)
    : storeListenSocket_(std::move(storeListenSocket)) {
  initStopSignal();
}
// Out-of-line definition required for the pure virtual destructor. It does
// not stop or join the daemon thread; that is done by dispose().
BackgroundThread::~BackgroundThread() = default;
// WARNING:
// The daemon thread executes a member function of the derived class, so it
// has to be stopped before the derived object goes away. The base destructor
// therefore does no clean-up itself: every subclass must call dispose() from
// its own destructor.
//
// Order matters: signal the thread, wait for it to exit, and only then drop
// the client sockets and the stop signal it was polling.
void BackgroundThread::dispose() {
  stop();
  join();
  sockets_.clear();
  closeStopSignal();
}
// Blocks until the daemon thread has returned from its run loop.
void BackgroundThread::join() {
  daemonThread_.join();
}
#ifdef _WIN32
// Windows: the stop signal is an unnamed event object created with
// CreateEvent(NULL, TRUE, FALSE, NULL). Raises a TORCH_CHECK failure when the
// event cannot be created.
void BackgroundThread::initStopSignal() {
  ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL);
  const bool created = (ghStopEvent_ != NULL);
  if (!created) {
    TORCH_CHECK(
        false,
        "Failed to create the control pipe to start the "
        "BackgroundThread run");
  }
}
// Windows: releases the event handle created by initStopSignal().
void BackgroundThread::closeStopSignal() {
  CloseHandle(ghStopEvent_);
}
// Windows: sets the stop event; the daemon loops check it with
// WaitForSingleObject().
void BackgroundThread::stop() {
  SetEvent(ghStopEvent_);
}
#else
// Unix: the stop signal is a pipe whose descriptors are stored in
// controlPipeFd_. Raises a TORCH_CHECK failure when pipe() fails.
void BackgroundThread::initStopSignal() {
  const int rc = ::pipe(controlPipeFd_.data());
  if (rc == -1) {
    TORCH_CHECK(
        false,
        "Failed to create the control pipe to start the "
        "BackgroundThread run");
  }
}
// Unix: closes whichever ends of the control pipe are still open (-1 marks
// an end that is not open).
void BackgroundThread::closeStopSignal() {
  for (const int pipeFd : controlPipeFd_) {
    if (pipeFd == -1) {
      continue;
    }
    ::close(pipeFd);
  }
}
// Unix: wakes the daemon thread by writing one byte to the control pipe and
// then closing the write end. Does nothing if the write end is already
// closed, so calling it more than once is harmless.
void BackgroundThread::stop() {
  int& writeEnd = controlPipeFd_[1];
  if (writeEnd == -1) {
    return;
  }
  // The byte makes the read end report POLLIN even when the close below
  // would not produce POLLHUP.
  ::write(writeEnd, "\0", 1);
  ::close(writeEnd);
  writeEnd = -1;
}
#endif
// Request codes a client sends to the master daemon as the first byte of a
// query. The numeric values go over the wire, so they must stay stable.
enum class QueryType : uint8_t {
  SET = 0,
  COMPARE_SET = 1,
  GET = 2,
  ADD = 3,
  CHECK = 4,
  WAIT = 5,
  GETNUMKEYS = 6,
  WATCH_KEY = 7,
  DELETE_KEY = 8,
};
// Reply to a CHECK query: READY when every requested key exists.
enum class CheckResponseType : uint8_t {
  READY = 0,
  NOT_READY = 1,
};
// Reply to a WAIT query, sent once all awaited keys exist.
enum class WaitResponseType : uint8_t {
  STOP_WAITING = 0,
};
// Messages the master sends on a watch connection: either a change
// notification for a watched key or the acknowledgement of a WATCH_KEY
// registration.
enum class WatchResponseType : uint8_t {
  KEY_UPDATED = 0,
  KEY_CREATED = 1,
  KEY_DELETED = 2,
  KEY_CALLBACK_REGISTERED = 3,
};
// Server-side daemon, launched only on the master. A single thread accepts
// client connections and serves their queries against an in-memory
// key/value map.
class TCPStoreMasterDaemon : public BackgroundThread {
 public:
  explicit TCPStoreMasterDaemon(Socket&& storeListenSocket);

  ~TCPStoreMasterDaemon() override;

 private:
  // Thread entry point: the poll loop.
  void run();

  // Serves every client socket in `fds` that reported an event and drops the
  // connections whose query failed.
  void queryFds(std::vector<struct pollfd>& fds);

  // Reads one query from `socket` and dispatches it to a handler.
  void query(int socket);

  // All handlers run on the daemon thread, so only one executes at a time.
  void setHandler(int socket);
  void compareSetHandler(int socket);
  void addHandler(int socket);
  void getHandler(int socket) const;
  void checkHandler(int socket) const;
  void getNumKeysHandler(int socket) const;
  void deleteHandler(int socket);
  void waitHandler(int socket);
  void watchHandler(int socket);

  // True when every key in `keys` is present in the store.
  bool checkKeys(const std::vector<std::string>& keys) const;

  // Notifies clients blocked on `key`; called by handlers that create or
  // update a key (e.g. setHandler, addHandler).
  void wakeupWaitingClients(const std::string& key);

  // Pushes a change notification for `key` to every watching client; called
  // by the handlers that create, modify or delete a key.
  void sendKeyUpdatesToClients(
      const std::string& key,
      const enum WatchResponseType& type,
      std::vector<uint8_t>& oldData,
      std::vector<uint8_t>& newData);

  // The store itself: key -> value bytes.
  std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
  // key -> sockets waiting for that key to appear.
  std::unordered_map<std::string, std::vector<int>> waitingSockets_;
  // socket -> number of keys it is still waiting for.
  std::unordered_map<int, size_t> keysAwaited_;
  // key -> sockets watching that key for changes.
  std::unordered_map<std::string, std::vector<int>> watchedSockets_;
};
// Hands the listening socket to the base class and launches the daemon
// thread, which executes run() on this object.
TCPStoreMasterDaemon::TCPStoreMasterDaemon(Socket&& storeListenSocket)
    : BackgroundThread(std::move(storeListenSocket)) {
  daemonThread_ = std::thread(&TCPStoreMasterDaemon::run, this);
}
// Stops and joins the daemon thread before this object's members (which the
// thread uses) are destroyed.
TCPStoreMasterDaemon::~TCPStoreMasterDaemon() {
  dispose();
}
void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
// Skipping the fds[0] and fds[1],
// fds[0] is master's listening socket
// fds[1] is control pipe's reading fd, it is not for Windows platform
for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) {
if (fds[fdIdx].revents == 0) {
continue;
}
// Now query the socket that has the event
try {
query(fds[fdIdx].fd);
} catch (...) {
// There was an error when processing query. Probably an exception
// occurred in recv/send what would indicate that socket on the other
// side has been closed. If the closing was due to normal exit, then
// the store should continue executing. Otherwise, if it was different
// exception, other connections will get an exception once they try to
// use the store. We will go ahead and close this connection whenever
// we hit an exception here.
// Remove all the tracking state of the close FD
for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
if (*vecIt == fds[fdIdx].fd) {
vecIt = it->second.erase(vecIt);
} else {
++vecIt;
}
}
if (it->second.size() == 0) {
it = waitingSockets_.erase(it);
} else {
++it;
}
}
for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
if (it->first == fds[fdIdx].fd) {
it = keysAwaited_.erase(it);
} else {
++it;
}
}
fds.erase(fds.begin() + fdIdx);
sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET);
--fdIdx;
continue;
}
}
}
// Reads one request from a worker and dispatches it. Wire format:
//   type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
// or, for check/wait:
//   type of query | number of args | size of arg1 | arg1 | ...
// Only the leading type byte is consumed here; each handler reads its own
// arguments. An unknown type raises a TORCH_CHECK failure.
void TCPStoreMasterDaemon::query(int socket) {
  QueryType qt;
  tcputil::recvBytes<QueryType>(socket, &qt, 1);
  switch (qt) {
    case QueryType::SET:
      setHandler(socket);
      break;
    case QueryType::COMPARE_SET:
      compareSetHandler(socket);
      break;
    case QueryType::ADD:
      addHandler(socket);
      break;
    case QueryType::GET:
      getHandler(socket);
      break;
    case QueryType::CHECK:
      checkHandler(socket);
      break;
    case QueryType::WAIT:
      waitHandler(socket);
      break;
    case QueryType::GETNUMKEYS:
      getNumKeysHandler(socket);
      break;
    case QueryType::DELETE_KEY:
      deleteHandler(socket);
      break;
    case QueryType::WATCH_KEY:
      watchHandler(socket);
      break;
    default:
      TORCH_CHECK(false, "Unexpected query type");
  }
}
// Called when `key` becomes available. Every socket waiting on the key has
// its outstanding-key counter decremented; a socket whose counter reaches
// zero is sent STOP_WAITING. The key's waiter list is then discarded.
void TCPStoreMasterDaemon::wakeupWaitingClients(const std::string& key) {
  const auto waiters = waitingSockets_.find(key);
  if (waiters == waitingSockets_.end()) {
    return;
  }
  for (const int waiter : waiters->second) {
    size_t& remaining = keysAwaited_[waiter];
    --remaining;
    if (remaining == 0) {
      tcputil::sendValue<WaitResponseType>(
          waiter, WaitResponseType::STOP_WAITING);
    }
  }
  waitingSockets_.erase(waiters);
}
// Sends a change notification for `key` to every socket watching it. Each
// notification is: the event type, the key, the old value, the new value.
//
// The watcher list is looked up with find() rather than operator[] so that
// writing a key nobody watches does not insert an empty entry into
// watchedSockets_ (which would otherwise grow by one entry per key written).
void TCPStoreMasterDaemon::sendKeyUpdatesToClients(
    const std::string& key,
    const enum WatchResponseType& type,
    std::vector<uint8_t>& oldData,
    std::vector<uint8_t>& newData) {
  const auto watchers = watchedSockets_.find(key);
  if (watchers == watchedSockets_.end()) {
    return;
  }
  for (const int socket : watchers->second) {
    tcputil::sendValue<WatchResponseType>(socket, type);
    tcputil::sendString(socket, key, true);
    tcputil::sendVector<uint8_t>(socket, oldData);
    tcputil::sendVector<uint8_t>(socket, newData);
  }
}
void TCPStoreMasterDaemon::setHandler(int socket) {
std::string key = tcputil::recvString(socket);
std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
std::vector<uint8_t> oldData;
bool newKey = true;
auto it = tcpStore_.find(key);
if (it != tcpStore_.end()) {
oldData = it->second;
newKey = false;
}
tcpStore_[key] = newData;
// On "set", wake up all clients that have been waiting
wakeupWaitingClients(key);
// Send key update to all watching clients
newKey ? sendKeyUpdatesToClients(
key, WatchResponseType::KEY_CREATED, oldData, newData)
: sendKeyUpdatesToClients(
key, WatchResponseType::KEY_UPDATED, oldData, newData);
}
void TCPStoreMasterDaemon::compareSetHandler(int socket) {
std::string key = tcputil::recvString(socket);
std::vector<uint8_t> currentValue = tcputil::recvVector<uint8_t>(socket);
std::vector<uint8_t> newValue = tcputil::recvVector<uint8_t>(socket);
auto pos = tcpStore_.find(key);
if (pos == tcpStore_.end()) {
if (currentValue.empty()) {
tcpStore_[key] = newValue;
// Send key update to all watching clients
sendKeyUpdatesToClients(
key, WatchResponseType::KEY_CREATED, currentValue, newValue);
tcputil::sendVector<uint8_t>(socket, newValue);
} else {
// TODO: This code path is not ideal as we are "lying" to the caller in
// case the key does not exist. We should come up with a working solution.
tcputil::sendVector<uint8_t>(socket, currentValue);
}
} else {
if (pos->second == currentValue) {
pos->second = std::move(newValue);
// Send key update to all watching clients
sendKeyUpdatesToClients(
key, WatchResponseType::KEY_UPDATED, currentValue, pos->second);
}
tcputil::sendVector<uint8_t>(socket, pos->second);
}
}
void TCPStoreMasterDaemon::addHandler(int socket) {
std::string key = tcputil::recvString(socket);
int64_t addVal = tcputil::recvValue<int64_t>(socket);
bool newKey = true;
std::vector<uint8_t> oldData;
auto it = tcpStore_.find(key);
if (it != tcpStore_.end()) {
oldData = it->second;
auto buf = reinterpret_cast<const char*>(it->second.data());
auto len = it->second.size();
addVal += std::stoll(std::string(buf, len));
newKey = false;
}
auto addValStr = std::to_string(addVal);
std::vector<uint8_t> newData =
std::vector<uint8_t>(addValStr.begin(), addValStr.end());
tcpStore_[key] = newData;
// Now send the new value
tcputil::sendValue<int64_t>(socket, addVal);
// On "add", wake up all clients that have been waiting
wakeupWaitingClients(key);
// Send key update to all watching clients
newKey ? sendKeyUpdatesToClients(
key, WatchResponseType::KEY_CREATED, oldData, newData)
: sendKeyUpdatesToClients(
key, WatchResponseType::KEY_UPDATED, oldData, newData);
}
// GET: reads a key and replies with its value. at() throws std::out_of_range
// for a missing key. The value is sent straight from the store through a
// reference; no temporary copy of the payload is made.
void TCPStoreMasterDaemon::getHandler(int socket) const {
  std::string key = tcputil::recvString(socket);
  const auto& data = tcpStore_.at(key);
  tcputil::sendVector<uint8_t>(socket, data);
}
// GETNUMKEYS: replies with the number of keys currently stored, as int64.
void TCPStoreMasterDaemon::getNumKeysHandler(int socket) const {
  const auto numKeys = static_cast<int64_t>(tcpStore_.size());
  tcputil::sendValue<int64_t>(socket, numKeys);
}
// DELETE_KEY: reads a key, notifies watchers with KEY_DELETED if the key
// exists, removes it and replies with the number of entries removed (0 or 1)
// as int64.
void TCPStoreMasterDaemon::deleteHandler(int socket) {
  const std::string key = tcputil::recvString(socket);

  const auto existing = tcpStore_.find(key);
  if (existing != tcpStore_.end()) {
    // Watchers receive the old value and an empty new value.
    std::vector<uint8_t> previous = existing->second;
    std::vector<uint8_t> none;
    sendKeyUpdatesToClients(
        key, WatchResponseType::KEY_DELETED, previous, none);
  }

  const auto numDeleted = tcpStore_.erase(key);
  tcputil::sendValue<int64_t>(socket, numDeleted);
}
void TCPStoreMasterDaemon::checkHandler(int socket) const {
SizeType nargs;
tcputil::recvBytes<SizeType>(socket, &nargs, 1);
std::vector<std::string> keys(nargs);
for (const auto i : c10::irange(nargs)) {
keys[i] = tcputil::recvString(socket);
}
// Now we have received all the keys
if (checkKeys(keys)) {
tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY);
} else {
tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY);
}
}
// WAIT: reads a key count followed by that many keys. If all keys already
// exist the client is answered with STOP_WAITING right away. Otherwise the
// socket is registered as a waiter on each missing key and the number of
// missing keys is recorded; the reply is sent later by
// wakeupWaitingClients().
void TCPStoreMasterDaemon::waitHandler(int socket) {
  SizeType nargs;
  tcputil::recvBytes<SizeType>(socket, &nargs, 1);

  std::vector<std::string> keys(nargs);
  for (std::string& key : keys) {
    key = tcputil::recvString(socket);
  }

  if (checkKeys(keys)) {
    tcputil::sendValue<WaitResponseType>(
        socket, WaitResponseType::STOP_WAITING);
    return;
  }

  int numKeysToAwait = 0;
  for (const std::string& key : keys) {
    // Only count keys that have not already been set
    if (tcpStore_.find(key) != tcpStore_.end()) {
      continue;
    }
    waitingSockets_[key].push_back(socket);
    ++numKeysToAwait;
  }
  keysAwaited_[socket] = numKeysToAwait;
}
// WATCH_KEY: reads a key, registers the requesting socket as a watcher of
// that key and acknowledges with KEY_CALLBACK_REGISTERED.
void TCPStoreMasterDaemon::watchHandler(int socket) {
  const std::string key = tcputil::recvString(socket);

  // Future changes of the key are pushed to this socket.
  std::vector<int>& watchers = watchedSockets_[key];
  watchers.push_back(socket);

  // The acknowledgement is consumed by the TCPStoreWorkerDaemon on the
  // client side.
  tcputil::sendValue<WatchResponseType>(
      socket, WatchResponseType::KEY_CALLBACK_REGISTERED);
}
// Returns true when every key in `keys` is present in the store (and hence
// also for an empty list).
bool TCPStoreMasterDaemon::checkKeys(
    const std::vector<std::string>& keys) const {
  for (const std::string& key : keys) {
    if (tcpStore_.find(key) == tcpStore_.end()) {
      return false;
    }
  }
  return true;
}
#ifdef _WIN32
// Windows main loop of the master daemon. There is no control pipe on this
// platform: the loop polls the sockets with a timeout of checkTimeout_ and,
// each time the poll times out, checks whether the stop event has been set.
void TCPStoreMasterDaemon::run() {
std::vector<struct pollfd> fds;
// fds[0] is the store's listening socket; accepted clients are appended.
tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
// receive the queries
bool finished = false;
while (!finished) {
for (const auto i : c10::irange(sockets_.size())) {
fds[i].revents = 0;
}
int res;
SYSCHECK_ERR_RETURN_NEG1(
res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
// res == 0 means the poll timed out without any socket event. This is the
// only point where the stop event is examined.
if (res == 0) {
auto rv = WaitForSingleObject(ghStopEvent_, 0);
if (rv != WAIT_TIMEOUT) {
finished = true;
break;
}
continue;
}
// TCPStore's listening socket has an event and it should now be able to
// accept new connections.
if (fds[0].revents != 0) {
if (!(fds[0].revents & POLLIN)) {
throw std::system_error(
ECONNABORTED,
std::system_category(),
"Unexpected poll revent on the master's listening socket: " +
std::to_string(fds[0].revents));
}
// Keep the Socket object in sockets_ and poll its raw handle.
Socket socket = storeListenSocket_.accept();
int rawSocket = socket.handle();
sockets_.emplace_back(std::move(socket));
tcputil::addPollfd(fds, rawSocket, POLLIN);
}
// Serve the client sockets that reported an event.
queryFds(fds);
}
}
#else
void TCPStoreMasterDaemon::run() {
std::vector<struct pollfd> fds;
tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
// Although we haven't found any documentation or literature describing this,
// we've seen cases that, under certain circumstances, the read end of the
// pipe won't receive POLLHUP when the write end is closed. However, under
// the same circumstances, writing to the pipe will guarantee POLLIN to be
// received on the read end.
//
// For more reliable termination, the main thread will write a byte to the
// pipe before closing it, and the background thread will poll for both
// POLLIN and POLLHUP.
tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);
// receive the queries
bool finished = false;
while (!finished) {
for (const auto i : c10::irange(sockets_.size())) {
fds[i].revents = 0;
}
SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));
// TCPStore's listening socket has an event and it should now be able to
// accept new connections.
if (fds[0].revents != 0) {
if (fds[0].revents ^ POLLIN) {
throw std::system_error(
ECONNABORTED,
std::system_category(),
"Unexpected poll revent on the master's listening socket: " +
std::to_string(fds[0].revents));
}
Socket socket = storeListenSocket_.accept();
int rawSocket = socket.handle();
sockets_.emplace_back(std::move(socket));
tcputil::addPollfd(fds, rawSocket, POLLIN);
}
// The pipe receives an event which tells us to shutdown the daemon
if (fds[1].revents != 0) {
// The main thread will write a byte to the pipe then close it before
// joining the background thread
if (fds[1].revents & ~(POLLIN | POLLHUP)) {
throw std::system_error(
ECONNABORTED,
std::system_category(),
"Unexpected poll revent on the control pipe's reading fd: " +
std::to_string(fds[1].revents));
}
finished = true;
break;
}
queryFds(fds);
}
}
#endif
// Client-side daemon, launched on every instance (including the master).
// Right now it only handles the callbacks registered through watchKey().
class TCPStoreWorkerDaemon : public BackgroundThread {
 public:
  explicit TCPStoreWorkerDaemon(Socket&& listenSocket);

  ~TCPStoreWorkerDaemon() override;

  // Registers `cb` as the callback to run when `key` changes.
  void setCallback(std::string key, WatchKeyCallback cb);

  // Blocks until setCallbackRegistered() has been called, then clears the
  // flag so the next registration can be awaited.
  void waitForCallbackRegistration() {
    std::unique_lock<std::mutex> lock(callbackRegistrationMutex_);
    while (!callbackRegisteredData_) {
      callbackRegisteredCV_.wait(lock);
    }
    callbackRegisteredData_ = false;
  }

  // Marks the pending registration as acknowledged and wakes one thread
  // blocked in waitForCallbackRegistration().
  void setCallbackRegistered() {
    {
      const std::lock_guard<std::mutex> lock(callbackRegistrationMutex_);
      callbackRegisteredData_ = true;
    }
    callbackRegisteredCV_.notify_one();
  }

 private:
  // Thread entry point: the poll loop.
  void run();

  // Reads one message from the master and runs the matching callback.
  void callbackHandler(int socket);

  // Watched key -> callback; guarded by keyToCallbacksMutex_.
  std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_;
  std::mutex keyToCallbacksMutex_;

  // Handshake state for callback registration.
  std::mutex callbackRegistrationMutex_;
  std::condition_variable callbackRegisteredCV_;
  bool callbackRegisteredData_ = false;
};
// TCPStoreWorkerDaemon member definitions.
//
// Hands the socket to the base class and launches the daemon thread, which
// executes run() on this object.
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(Socket&& listenSocket)
    : BackgroundThread(std::move(listenSocket)) {
  daemonThread_ = std::thread(&TCPStoreWorkerDaemon::run, this);
}
// Stops and joins the daemon thread before this object's members (which the
// thread uses) are destroyed.
TCPStoreWorkerDaemon::~TCPStoreWorkerDaemon() {
  dispose();
}
// Registers (or replaces) the callback for `key`. Both arguments are taken
// by value, so they are moved into the map instead of being copied a second
// time. The map is updated under keyToCallbacksMutex_ because the daemon
// thread reads it in callbackHandler().
void TCPStoreWorkerDaemon::setCallback(
    std::string key,
    WatchKeyCallback callback) {
  const std::lock_guard<std::mutex> lock(keyToCallbacksMutex_);
  keyToCallbacks_[std::move(key)] = std::move(callback);
}
// Runs all the callbacks that the worker has registered
void TCPStoreWorkerDaemon::callbackHandler(int socket) {
auto watchResponse = tcputil::recvValue<WatchResponseType>(socket);
if (watchResponse == WatchResponseType::KEY_CALLBACK_REGISTERED) {
// Notify the waiting "watchKey" operation to return
setCallbackRegistered();
return;
}
std::string key = tcputil::recvString(socket);
std::vector<uint8_t> currentValueVec = tcputil::recvVector<uint8_t>(socket);
std::vector<uint8_t> newValueVec = tcputil::recvVector<uint8_t>(socket);
c10::optional<std::string> currentValue;
if (watchResponse == WatchResponseType::KEY_CREATED) {
assert(currentValueVec.empty());
currentValue = c10::nullopt;
} else {
currentValue = std::string(currentValueVec.begin(), currentValueVec.end());
}
c10::optional<std::string> newValue;
if (watchResponse == WatchResponseType::KEY_DELETED) {
assert(newValueVec.empty());
newValue = c10::nullopt;
} else {
newValue = std::string(newValueVec.begin(), newValueVec.end());
}
const std::lock_guard<std::mutex> lock(keyToCallbacksMutex_);
keyToCallbacks_.at(key)(currentValue, newValue);
}
#ifdef _WIN32
// Windows main loop of the worker daemon. Polls the connection to the master
// with a timeout of checkTimeout_; the stop event is examined whenever the
// poll times out or the peeked read reports a closed connection.
void TCPStoreWorkerDaemon::run() {
std::vector<struct pollfd> fds;
// fds[0] is the connection to the master.
tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
while (true) {
// Check control and exit early if triggered
int res;
SYSCHECK_ERR_RETURN_NEG1(
res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
// res == 0: the poll timed out without a socket event.
if (res == 0) {
auto rvPoll = WaitForSingleObject(ghStopEvent_, 0);
if (rvPoll != WAIT_TIMEOUT) {
break;
}
continue;
}
// if connection is closed gracefully by master, peeked data will return 0
char data;
int ret = recv(fds[0].fd, &data, 1, MSG_PEEK);
if (ret == 0) {
auto rvData = WaitForSingleObject(ghStopEvent_, 0);
if (rvData != WAIT_TIMEOUT) {
break;
}
continue;
}
// valid request, perform callback logic
callbackHandler(fds[0].fd);
}
}
#else
// Unix main loop of the worker daemon. Polls, without timeout, the control
// pipe's read end (fds[0]) and the connection to the master (fds[1]).
// Returns when the control pipe reports an event.
//
// Once the master has closed the connection, the socket is taken out of the
// poll set. A socket at end-of-stream stays readable forever, so continuing
// to poll it would make poll() return immediately on every iteration and spin
// this thread until the stop signal arrives.
void TCPStoreWorkerDaemon::run() {
  std::vector<struct pollfd> fds;
  // Although we haven't found any documentation or literature describing this,
  // we've seen cases that, under certain circumstances, the read end of the
  // pipe won't receive POLLHUP when the write end is closed. However, under
  // the same circumstances, writing to the pipe will guarantee POLLIN to be
  // received on the read end.
  //
  // For more reliable termination, the main thread will write a byte to the
  // pipe before closing it, and the background thread will poll for both
  // POLLIN and POLLHUP.
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
  while (true) {
    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));
    // Check control and exit early if triggered
    // The pipe receives an event which tells us to shutdown the listener thread
    if (fds[0].revents != 0) {
      // The main thread will write a byte to the pipe then close it before
      // joining the background thread
      if (fds[0].revents & ~(POLLIN | POLLHUP)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[0].revents));
      }
      break;
    }
    // Nothing to read on the master connection (also the case once it has
    // been removed from the poll set below).
    if (fds[1].revents == 0) {
      continue;
    }
    // if connection is closed gracefully by master, peeked data will return 0
    char data;
    int ret = recv(fds[1].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      // poll() ignores entries with a negative fd, so from here on only the
      // control pipe can wake this thread up. The socket itself stays owned
      // by storeListenSocket_.
      fds[1].fd = -1;
      continue;
    }
    // valid request, perform callback logic
    callbackHandler(fds[1].fd);
  }
}
#endif
} // namespace
// Owns a master daemon and remembers the port it listens on. Instances are
// obtained through start(), which can share one server per port between
// callers that asked for multi-tenancy.
class TCPServer {
 public:
  // Returns a running server for `opts`, either newly started or (for
  // multi-tenant requests) an already running one on the same port.
  static std::shared_ptr<TCPServer> start(const TCPStoreOptions& opts);

  std::uint16_t port() const noexcept {
    return port_;
  }

  explicit TCPServer(
      std::uint16_t port,
      std::unique_ptr<TCPStoreMasterDaemon>&& daemon)
      : port_(port), daemon_(std::move(daemon)) {}

 private:
  std::uint16_t port_;
  std::unique_ptr<TCPStoreMasterDaemon> daemon_;

  // Weak references to all servers started for multi-tenant callers, keyed
  // by port. Guarded by cache_mutex_.
  static std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
      cachedServers_;
  static std::mutex cache_mutex_;
};
// Definitions of TCPServer's static members: the multi-tenant server cache
// and the mutex protecting it.
std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
    TCPServer::cachedServers_;
std::mutex TCPServer::cache_mutex_;
// Starts (or reuses) a server for `opts`.
//
// Without multi-tenancy a new server is always started. With multi-tenancy
// the cache is consulted first when a specific port was requested; a live
// entry is returned as is, an expired one is dropped. A newly started server
// is then recorded in the cache under the port it actually listens on.
//
// The cache entry is written with operator[] rather than emplace(). When
// opts.port is 0 the lookup above is skipped, so an expired entry for the
// port the new server ends up on may still be present; emplace() would
// silently keep that expired entry and the new server would never be found
// by later multi-tenant callers.
std::shared_ptr<TCPServer> TCPServer::start(const TCPStoreOptions& opts) {
  auto startCore = [&opts]() {
    Socket socket = Socket::listen(opts.port);

    std::uint16_t port = socket.port();

    auto daemon = std::make_unique<TCPStoreMasterDaemon>(std::move(socket));

    return std::make_shared<TCPServer>(port, std::move(daemon));
  };

  if (!opts.multiTenant) {
    return startCore();
  }

  std::lock_guard<std::mutex> guard{cache_mutex_};

  // If the caller is okay with a multi-tenant store, first check if we
  // already have a TCPServer running on the specified port.
  if (opts.port > 0) {
    auto pos = cachedServers_.find(opts.port);
    if (pos != cachedServers_.end()) {
      std::shared_ptr<TCPServer> running = pos->second.lock();
      if (running != nullptr) {
        return running;
      }
      // Looks like the TCPStore has been disposed, make sure that we release
      // the control block.
      cachedServers_.erase(pos);
    }
  }

  std::shared_ptr<TCPServer> server = startCore();
  cachedServers_[server->port()] = server;
  return server;
}
// Thin wrapper around the client's connection to the master daemon: typed
// helpers for sending queries and receiving replies over one socket.
class TCPClient {
 public:
  // Connects to the store at `addr`, using opts.timeout as connect timeout.
  static std::unique_ptr<TCPClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts);

  // Sends a query type that takes no key.
  void sendCommand(QueryType type) {
    tcputil::sendValue<QueryType>(socket_.handle(), type);
  }

  // Sends a query type followed by a key.
  void sendCommandForKey(QueryType type, const std::string& key);

  void sendBytes(const std::vector<std::uint8_t>& value) {
    tcputil::sendVector<std::uint8_t>(socket_.handle(), value);
  }

  // Sends a count followed by each string.
  void sendStrings(c10::ArrayRef<std::string> value);

  template <typename T>
  void sendValue(const T& value) {
    tcputil::sendValue<T>(socket_.handle(), value);
  }

  // Receives a byte vector from the master.
  std::vector<std::uint8_t> receiveBits() {
    return tcputil::recvVector<std::uint8_t>(socket_.handle());
  }

  template <typename T>
  T receiveValue() {
    return tcputil::recvValue<T>(socket_.handle());
  }

  // Sets the receive timeout of the socket; zero leaves it unchanged.
  void setTimeout(std::chrono::milliseconds value);

  explicit TCPClient(Socket&& socket) : socket_(std::move(socket)) {}

 private:
  Socket socket_;
};
std::unique_ptr<TCPClient> TCPClient::connect(
const SocketAddress& addr,
const TCPStoreOptions& opts) {
auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
Socket socket = Socket::connect(addr.host,
addr.port,
SocketOptions{}.connect_timeout(timeout));
return std::make_unique<TCPClient>(std::move(socket));
}
// Sends the query type followed by the key. For SET, COMPARE_SET and ADD,
// which are followed by further payload, `true` is passed as sendString's
// third argument (presumably a "more data follows" hint; confirm in
// tcputil).
void TCPClient::sendCommandForKey(QueryType type, const std::string& key) {
  tcputil::sendValue<QueryType>(socket_.handle(), type);

  bool withValue = false;
  switch (type) {
    case QueryType::SET:
    case QueryType::COMPARE_SET:
    case QueryType::ADD:
      withValue = true;
      break;
    default:
      break;
  }

  tcputil::sendString(socket_.handle(), key, withValue);
}
// Sends the number of strings followed by each string in order. The last
// argument of each send is true for everything but the final item sent.
void TCPClient::sendStrings(c10::ArrayRef<std::string> value) {
  std::size_t count = value.size();
  tcputil::sendBytes<std::size_t>(socket_.handle(), &count, 1, count > 0);

  for (std::size_t idx = 0; idx < count; ++idx) {
    const bool moreToCome = (idx + 1 < count);
    tcputil::sendString(socket_.handle(), value[idx], moreToCome);
  }
}
// Applies `value` as the socket's receive timeout (SO_RCVTIMEO). A zero
// duration is ignored, so a previously set timeout stays in effect.
//
// NOTE(review): the option value is passed as a struct timeval on every
// platform; confirm that this is the layout Winsock expects for SO_RCVTIMEO.
void TCPClient::setTimeout(std::chrono::milliseconds value) {
  if (value == std::chrono::milliseconds::zero()) {
    return;
  }

  // Split the duration into whole seconds and remaining microseconds. The
  // casts target whatever field types the platform's timeval declares.
  struct timeval timeoutTV {};
  timeoutTV.tv_sec =
      static_cast<decltype(timeoutTV.tv_sec)>(value.count() / 1000);
  timeoutTV.tv_usec =
      static_cast<decltype(timeoutTV.tv_usec)>((value.count() % 1000) * 1000);

  SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
      socket_.handle(),
      SOL_SOCKET,
      SO_RCVTIMEO,
      reinterpret_cast<char*>(&timeoutTV),
      sizeof(timeoutTV)));
}
// Client side of the key-watching feature: a second connection to the master
// whose incoming messages are processed by a TCPStoreWorkerDaemon.
class TCPCallbackClient {
 public:
  // Connects to the store at `addr` and starts the worker daemon on the new
  // connection.
  static std::unique_ptr<TCPCallbackClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts);

  // Registers `callback` for `key` and waits for the master's
  // acknowledgement.
  void setCallback(const std::string& key, WatchKeyCallback callback);

  explicit TCPCallbackClient(
      int rawSocket,
      std::unique_ptr<TCPStoreWorkerDaemon>&& daemon)
      : rawSocket_(rawSocket), daemon_(std::move(daemon)) {}

 private:
  // Raw handle of the socket owned by the daemon; used for sending requests.
  int rawSocket_;
  std::unique_ptr<TCPStoreWorkerDaemon> daemon_;
  // Serializes setCallback() calls.
  std::mutex mutex_;
};
std::unique_ptr<TCPCallbackClient> TCPCallbackClient::connect(
const SocketAddress& addr,
const TCPStoreOptions& opts) {
auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
Socket socket = Socket::connect(addr.host,
addr.port,
SocketOptions{}.connect_timeout(timeout));
int rawSocket = socket.handle();
auto daemon = std::make_unique<TCPStoreWorkerDaemon>(std::move(socket));
return std::make_unique<TCPCallbackClient>(rawSocket, std::move(daemon));
}
// Registers `callback` for `key`: stores it in the daemon first, then sends a
// WATCH_KEY request for the key and blocks until the daemon has seen the
// master's acknowledgement. Calls are serialized by mutex_.
void TCPCallbackClient::setCallback(
    const std::string& key,
    WatchKeyCallback callback) {
  std::lock_guard<std::mutex> guard{mutex_};

  // The callback must be in place before the master can send notifications.
  daemon_->setCallback(key, callback);

  tcputil::sendValue<QueryType>(rawSocket_, QueryType::WATCH_KEY);
  tcputil::sendString(rawSocket_, key);

  daemon_->waitForCallbackRegistration();
}
} // namespace detail
using detail::Socket;
// TCPStore member definitions.
//
// Convenience constructor taking the individual settings: packs them into a
// TCPStoreOptions and delegates to the options-based constructor. The
// optional worker count is converted from int to std::size_t.
TCPStore::TCPStore(
    const std::string& masterAddr,
    std::uint16_t masterPort,
    c10::optional<int> numWorkers,
    bool isServer,
    const std::chrono::milliseconds& timeout,
    bool waitWorkers)
    : TCPStore{
          masterAddr,
          TCPStoreOptions{
              masterPort,
              isServer,
              numWorkers ? c10::optional<std::size_t>(*numWorkers)
                         : c10::nullopt,
              waitWorkers,
              timeout}} {}
// Main constructor. When opts.isServer is set, a server is started (or
// reused) first and its actual port is adopted. A client connection to the
// store is then opened, the worker rendezvous is optionally performed, and
// finally the callback connection is opened.
TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
    : Store{opts.timeout},
      addr_{std::move(host)},
      numWorkers_{opts.numWorkers} {
  Socket::initialize();

  addr_.port = opts.port;
  if (opts.isServer) {
    server_ = detail::TCPServer::start(opts);
    // The server reports the port it is really listening on.
    addr_.port = server_->port();
  }

  client_ = detail::TCPClient::connect(addr_, opts);

  if (opts.waitWorkers) {
    waitForWorkers();
  }

  callbackClient_ = detail::TCPCallbackClient::connect(addr_, opts);
}
// Defaulted destructor; clean-up is left to the members' own destructors.
TCPStore::~TCPStore() = default;
// Worker rendezvous. Does nothing when no worker count was configured.
// Otherwise the shared counter stored under initKey_ is incremented by one.
// On the server, the counter is then re-read every 10 ms until it reaches the
// configured worker count or, when a timeout is set, until more than
// timeout_ has elapsed (in which case the function returns without error).
void TCPStore::waitForWorkers() {
  if (!numWorkers_) {
    return;
  }

  incrementValueBy(initKey_, 1);

  // Only the server blocks here; this ensures that the server daemon thread
  // keeps running until all workers have checked in.
  if (!server_) {
    return;
  }

  const auto start = std::chrono::steady_clock::now();
  while (true) {
    // TODO: Any chance to make this cleaner?
    const std::vector<uint8_t> value = doGet(initKey_);
    const std::string counterText(value.begin(), value.end());
    int numWorkersCompleted = std::stoi(counterText);
    if (numWorkersCompleted >= *numWorkers_) {
      break;
    }

    const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
        std::chrono::steady_clock::now() - start);
    if (timeout_ != kNoTimeout && elapsed > timeout_) {
      break;
    }

    /* sleep override */
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
}
// Stores `data` under the (prefixed) key. No reply is read from the server.
void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  const std::string fullKey = keyPrefix_ + key;
  client_->sendCommandForKey(detail::QueryType::SET, fullKey);
  client_->sendBytes(data);
}
// Sends a compare-and-set request for the (prefixed) key with the expected
// and desired values, and returns the value the server replies with.
std::vector<uint8_t> TCPStore::compareSet(
    const std::string& key,
    const std::vector<uint8_t>& expectedValue,
    const std::vector<uint8_t>& desiredValue) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  const std::string fullKey = keyPrefix_ + key;
  client_->sendCommandForKey(detail::QueryType::COMPARE_SET, fullKey);
  client_->sendBytes(expectedValue);
  client_->sendBytes(desiredValue);
  return client_->receiveBits();
}
// Returns the value stored under the (prefixed) key; see doGet().
std::vector<uint8_t> TCPStore::get(const std::string& key) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  const std::string fullKey = keyPrefix_ + key;
  return doGet(fullKey);
}
// Waits (using the store's default timeout) for `key` to exist, then fetches
// its value. `key` is used as given, without prefixing, and activeOpLock_ is
// not taken here.
std::vector<uint8_t> TCPStore::doGet(const std::string& key) {
  doWait(key, timeout_);

  client_->sendCommandForKey(detail::QueryType::GET, key);
  std::vector<uint8_t> value = client_->receiveBits();
  return value;
}
// Adds `value` to the counter stored under the (prefixed) key and returns
// the new total.
int64_t TCPStore::add(const std::string& key, int64_t value) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  const std::string fullKey = keyPrefix_ + key;
  return incrementValueBy(fullKey, value);
}
// Deletes the (prefixed) key. Returns true when the server reports that
// exactly one entry was removed.
bool TCPStore::deleteKey(const std::string& key) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  const std::string fullKey = keyPrefix_ + key;
  client_->sendCommandForKey(detail::QueryType::DELETE_KEY, fullKey);
  const std::int64_t removed = client_->receiveValue<std::int64_t>();
  return removed == 1;
}
// Registers `callback` to be run when the (prefixed) key changes; the call
// goes through the callback client.
void TCPStore::watchKey(const std::string& key, WatchKeyCallback callback) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  const std::string fullKey = keyPrefix_ + key;
  callbackClient_->setCallback(fullKey, callback);
}
// Sends an ADD request for `key` (used as given, without prefixing) with
// `delta` and returns the total reported by the server. activeOpLock_ is not
// taken here.
int64_t TCPStore::incrementValueBy(const std::string& key, int64_t delta) {
  client_->sendCommandForKey(detail::QueryType::ADD, key);
  client_->sendValue<std::int64_t>(delta);
  const std::int64_t total = client_->receiveValue<std::int64_t>();
  return total;
}
// Returns the number of keys the server currently holds.
int64_t TCPStore::getNumKeys() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  client_->sendCommand(detail::QueryType::GETNUMKEYS);
  const std::int64_t numKeys = client_->receiveValue<std::int64_t>();
  return numKeys;
}
// Asks the server whether all of the given keys (after prefixing) exist.
// Returns true for READY and false for NOT_READY; any other reply raises a
// TORCH_CHECK failure.
bool TCPStore::check(const std::vector<std::string>& keys) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);

  std::vector<std::string> prefixedKeys;
  prefixedKeys.reserve(keys.size());
  for (const auto& key : keys) {
    prefixedKeys.push_back(keyPrefix_ + key);
  }

  client_->sendCommand(detail::QueryType::CHECK);
  client_->sendStrings(prefixedKeys);

  switch (client_->receiveValue<detail::CheckResponseType>()) {
    case detail::CheckResponseType::READY:
      return true;
    case detail::CheckResponseType::NOT_READY:
      return false;
  }
  TORCH_CHECK(false, "ready or not_ready response expected");
}
void TCPStore::wait(const std::vector<std::string>& keys) {
wait(keys, timeout_);
}
// Waits for all `keys` (after prefixing) to exist, using `timeout`.
void TCPStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);

  std::vector<std::string> prefixedKeys;
  prefixedKeys.reserve(keys.size());
  for (const auto& key : keys) {
    prefixedKeys.push_back(keyPrefix_ + key);
  }

  doWait(prefixedKeys, timeout);
}
// Sets the client's receive timeout to `timeout`, sends a WAIT request for
// `keys` (used as given, without prefixing) and blocks for the reply. Any
// reply other than STOP_WAITING raises a TORCH_CHECK failure. activeOpLock_
// is not taken here.
void TCPStore::doWait(
    c10::ArrayRef<std::string> keys,
    std::chrono::milliseconds timeout) {
  // TODO: Should we revert to the original timeout at the end of the call?
  client_->setTimeout(timeout);

  client_->sendCommand(detail::QueryType::WAIT);
  client_->sendStrings(keys);

  const auto response = client_->receiveValue<detail::WaitResponseType>();
  const bool released = (response == detail::WaitResponseType::STOP_WAITING);
  if (!released) {
    TORCH_CHECK(false, "Stop_waiting response is expected");
  }
}
} // namespace c10d