mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: This PR fixes a race condition in the TCP init method, where the master rank could exit earlier than the slave ranks, shutting down the TCP daemon thread before the other slaves were able to access it. With this change, every rank (process) writes a special key to the store to mark that it has completed (and is thus about to exit). The master rank (the server) always waits until all ranks have completed before completing itself. This should fix: https://github.com/pytorch/pytorch/issues/15638 Tested using the repro from https://github.com/pytorch/pytorch/issues/15638 and it works fine. test_distributed and test_c10d should already have this coverage. I had to give the rendezvous test in c10d a world size of 1, since it is single-process code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15684 Differential Revision: D13570904 Pulled By: teng-li fbshipit-source-id: 34f3bc471204bbd29320df359347ad5561c6b589
428 lines
13 KiB
C++
#include <c10d/TCPStore.hpp>

#include <poll.h>
#include <unistd.h>

#include <algorithm>
#include <system_error>
namespace c10d {

namespace {

enum class QueryType : uint8_t { SET, GET, ADD, CHECK, WAIT };

enum class CheckResponseType : uint8_t { READY, NOT_READY };

enum class WaitResponseType : uint8_t { STOP_WAITING };

} // anonymous namespace
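// These enums travel over the wire as single bytes; see query() below, which
// reads exactly one uint8_t to dispatch on the request type.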

// TCPStoreDaemon class methods
// Simply start the daemon thread
TCPStoreDaemon::TCPStoreDaemon(int storeListenSocket)
    : storeListenSocket_(storeListenSocket) {
  // Use control pipe to signal instance destruction to the daemon thread.
  if (pipe(controlPipeFd_.data()) == -1) {
    throw std::runtime_error(
        "Failed to create the control pipe to start the "
        "TCPStoreDaemon run");
  }
  daemonThread_ = std::thread(&TCPStoreDaemon::run, this);
}
TCPStoreDaemon::~TCPStoreDaemon() {
  // Stop the run
  stop();
  // Join the thread
  join();
  // Close unclosed sockets
  for (auto socket : sockets_) {
    if (socket != -1) {
      ::close(socket);
    }
  }
  // Now close the remaining control pipe fds
  for (auto fd : controlPipeFd_) {
    if (fd != -1) {
      ::close(fd);
    }
  }
}

void TCPStoreDaemon::join() {
  daemonThread_.join();
}
void TCPStoreDaemon::run() {
  std::vector<struct pollfd> fds;
  fds.push_back({.fd = storeListenSocket_, .events = POLLIN});
  // Push the read end of the pipe to signal the stopping of the daemon run
  fds.push_back({.fd = controlPipeFd_[0], .events = POLLHUP});

  // Receive the queries
  bool finished = false;
  while (!finished) {
    for (size_t i = 0; i < sockets_.size(); i++) {
      fds[i].revents = 0;
    }

    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      // Anything other than exactly POLLIN is unexpected here
      if (fds[0].revents ^ POLLIN) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
      sockets_.push_back(sockFd);
      fds.push_back({.fd = sockFd, .events = POLLIN});
    }
    // The pipe receives an event which tells us to shut down the daemon
    if (fds[1].revents != 0) {
      // Will be POLLHUP when the pipe is closed
      if (fds[1].revents ^ POLLHUP) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[1].revents));
      }
      finished = true;
      break;
    }
    // Skipping fds[0] and fds[1]:
    // fds[0] is the master's listening socket,
    // fds[1] is the control pipe's read fd.
    for (size_t fdIdx = 2; fdIdx < fds.size(); ++fdIdx) {
      if (fds[fdIdx].revents == 0) {
        continue;
      }

      if (fds[fdIdx].revents ^ POLLIN) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent: " + std::to_string(fds[fdIdx].revents) +
                " on socket: " + std::to_string(fds[fdIdx].fd));
      }
      // Now query the socket that has the event
      try {
        query(fds[fdIdx].fd);
      } catch (...) {
        // There was an error when processing the query. An exception most
        // likely occurred in recv/send, which would indicate that the socket
        // on the other side has been closed. If the closing was due to a
        // normal exit, the store should continue executing. Otherwise, if it
        // was a different exception, other connections will get an exception
        // once they try to use the store. Either way, close this connection
        // whenever we hit an exception here.
        ::close(fds[fdIdx].fd);

        // Remove all tracking state for the closed fd
        for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
          for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
            if (*vecIt == fds[fdIdx].fd) {
              vecIt = it->second.erase(vecIt);
            } else {
              ++vecIt;
            }
          }
          if (it->second.size() == 0) {
            it = waitingSockets_.erase(it);
          } else {
            ++it;
          }
        }
        for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
          if (it->first == fds[fdIdx].fd) {
            it = keysAwaited_.erase(it);
          } else {
            ++it;
          }
        }
        fds.erase(fds.begin() + fdIdx);
        sockets_.erase(sockets_.begin() + fdIdx - 2);
        --fdIdx;
        continue;
      }
    }
  }
}
void TCPStoreDaemon::stop() {
  if (controlPipeFd_[1] != -1) {
    // Close the write end of the pipe
    ::close(controlPipeFd_[1]);
    controlPipeFd_[1] = -1;
  }
}

// query communicates with the worker. The format of the query is as follows:
// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
// or, in the case of wait:
// type of query | number of args | size of arg1 | arg1 | ...
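//
// For illustration, assuming tcputil's sendString/sendVector helpers prefix
// each payload with its byte length, a SET request for key "foo" with the
// 3-byte value {1, 2, 3} would look roughly like:
//   [SET][3]["foo"][3][0x01 0x02 0x03]
// and a WAIT request on two keys like:
//   [WAIT][2][len(key1)][key1][len(key2)][key2]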
void TCPStoreDaemon::query(int socket) {
  QueryType qt;
  tcputil::recvBytes<QueryType>(socket, &qt, 1);

  if (qt == QueryType::SET) {
    setHandler(socket);

  } else if (qt == QueryType::ADD) {
    addHandler(socket);

  } else if (qt == QueryType::GET) {
    getHandler(socket);

  } else if (qt == QueryType::CHECK) {
    checkHandler(socket);

  } else if (qt == QueryType::WAIT) {
    waitHandler(socket);

  } else {
    throw std::runtime_error("Unexpected query type");
  }
}
void TCPStoreDaemon::wakeupWaitingClients(const std::string& key) {
  auto socketsToWait = waitingSockets_.find(key);
  if (socketsToWait != waitingSockets_.end()) {
    for (int socket : socketsToWait->second) {
      if (--keysAwaited_[socket] == 0) {
        tcputil::sendValue<WaitResponseType>(
            socket, WaitResponseType::STOP_WAITING);
      }
    }
    waitingSockets_.erase(socketsToWait);
  }
}
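// A client waiting on multiple keys is woken up only once: keysAwaited_
// counts its outstanding keys, and STOP_WAITING is sent when that count
// drops to zero (see waitHandler below, which sets the initial count).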

void TCPStoreDaemon::setHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  tcpStore_[key] = tcputil::recvVector<uint8_t>(socket);
  // On "set", wake up all clients that have been waiting
  wakeupWaitingClients(key);
}

void TCPStoreDaemon::addHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  int64_t addVal = tcputil::recvValue<int64_t>(socket);

  if (tcpStore_.find(key) != tcpStore_.end()) {
    auto buf = reinterpret_cast<const char*>(tcpStore_[key].data());
    auto len = tcpStore_[key].size();
    addVal += std::stoll(std::string(buf, len));
  }
  auto addValStr = std::to_string(addVal);
  tcpStore_[key] = std::vector<uint8_t>(addValStr.begin(), addValStr.end());
  // Now send the new value
  tcputil::sendValue<int64_t>(socket, addVal);
  // On "add", wake up all clients that have been waiting
  wakeupWaitingClients(key);
}
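// Note that counters live in the store as ASCII decimal strings: addHandler
// parses the existing bytes with std::stoll and writes back the result of
// std::to_string, so a value written via add() can also be read with get().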

void TCPStoreDaemon::getHandler(int socket) const {
  std::string key = tcputil::recvString(socket);
  auto data = tcpStore_.at(key);
  tcputil::sendVector<uint8_t>(socket, data);
}

void TCPStoreDaemon::checkHandler(int socket) const {
  SizeType nargs;
  tcputil::recvBytes<SizeType>(socket, &nargs, 1);
  std::vector<std::string> keys(nargs);
  for (size_t i = 0; i < nargs; i++) {
    keys[i] = tcputil::recvString(socket);
  }
  // Now we have received all the keys
  if (checkKeys(keys)) {
    tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY);
  } else {
    tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY);
  }
}
void TCPStoreDaemon::waitHandler(int socket) {
  SizeType nargs;
  tcputil::recvBytes<SizeType>(socket, &nargs, 1);
  std::vector<std::string> keys(nargs);
  for (size_t i = 0; i < nargs; i++) {
    keys[i] = tcputil::recvString(socket);
  }
  if (checkKeys(keys)) {
    tcputil::sendValue<WaitResponseType>(
        socket, WaitResponseType::STOP_WAITING);
  } else {
    for (auto& key : keys) {
      waitingSockets_[key].push_back(socket);
    }
    keysAwaited_[socket] = keys.size();
  }
}

bool TCPStoreDaemon::checkKeys(const std::vector<std::string>& keys) const {
  return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) {
    return tcpStore_.count(s) > 0;
  });
}

// TCPStore class methods
TCPStore::TCPStore(
    const std::string& masterAddr,
    PortType masterPort,
    int numWorkers,
    bool isServer)
    : isServer_(isServer),
      tcpStoreAddr_(masterAddr),
      tcpStorePort_(masterPort),
      numWorkers_(numWorkers),
      initKey_("init/"),
      regularPrefix_("/") {
  if (isServer_) {
    // Opening up the listening socket
    std::tie(masterListenSocket_, std::ignore) = tcputil::listen(masterPort);
    // Now start the daemon
    tcpStoreDaemon_ = std::unique_ptr<TCPStoreDaemon>(
        new TCPStoreDaemon(masterListenSocket_));
  }
  // Connect to the daemon
  storeSocket_ = tcputil::connect(tcpStoreAddr_, tcpStorePort_);

  waitForWorkers_();
}
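// Sketch of typical usage, based on the constructor signature above (the
// address, port, and worker count here are placeholders):
//
//   // On the rank designated as the server:
//   c10d::TCPStore serverStore("127.0.0.1", 29500, /*numWorkers=*/4,
//                              /*isServer=*/true);
//   // On each worker rank:
//   c10d::TCPStore workerStore("127.0.0.1", 29500, 4, /*isServer=*/false);
//   workerStore.set("key", std::vector<uint8_t>{1, 2, 3});
//   auto value = workerStore.get("key");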

TCPStore::~TCPStore() {
  ::close(storeSocket_);
  if (isServer_) {
    // The store daemon should end because of the closed connection.
    // The daemon destructor joins the thread.
    tcpStoreDaemon_.reset(nullptr);
    ::close(masterListenSocket_);
  }
}
void TCPStore::waitForWorkers_() {
  addHelper_(initKey_, 1);
  // Let the server block until all workers have completed; this ensures that
  // the server daemon thread is always running until the very end
  if (isServer_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      std::vector<uint8_t> value = getHelper_(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= numWorkers_) {
        break;
      }
      const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
          std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        break;
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}
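// This is the heart of the race-condition fix described in the commit
// message: every rank bumps the shared initKey_ counter exactly once, and
// the server polls that counter until it reaches numWorkers_ (or the
// timeout expires), so its daemon thread stays alive until every rank has
// checked in.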

void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  std::string regKey = regularPrefix_ + key;
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::SET);
  tcputil::sendString(storeSocket_, regKey, true);
  tcputil::sendVector<uint8_t>(storeSocket_, data);
}

std::vector<uint8_t> TCPStore::get(const std::string& key) {
  std::string regKey = regularPrefix_ + key;
  return getHelper_(regKey);
}

std::vector<uint8_t> TCPStore::getHelper_(const std::string& key) {
  waitHelper_({key}, timeout_);
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::GET);
  tcputil::sendString(storeSocket_, key);
  return tcputil::recvVector<uint8_t>(storeSocket_);
}

int64_t TCPStore::add(const std::string& key, int64_t value) {
  std::string regKey = regularPrefix_ + key;
  return addHelper_(regKey, value);
}

int64_t TCPStore::addHelper_(const std::string& key, int64_t value) {
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::ADD);
  tcputil::sendString(storeSocket_, key, true);
  tcputil::sendValue<int64_t>(storeSocket_, value);
  return tcputil::recvValue<int64_t>(storeSocket_);
}
bool TCPStore::check(const std::vector<std::string>& keys) {
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::CHECK);
  SizeType nkeys = keys.size();
  tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
  for (size_t i = 0; i < nkeys; i++) {
    std::string regKey = regularPrefix_ + keys[i];
    tcputil::sendString(storeSocket_, regKey, (i != (nkeys - 1)));
  }
  auto checkResponse = tcputil::recvValue<CheckResponseType>(storeSocket_);
  if (checkResponse == CheckResponseType::READY) {
    return true;
  } else if (checkResponse == CheckResponseType::NOT_READY) {
    return false;
  } else {
    throw std::runtime_error("ready or not_ready response expected");
  }
}
void TCPStore::wait(const std::vector<std::string>& keys) {
  wait(keys, timeout_);
}

void TCPStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  std::vector<std::string> regKeys;
  regKeys.resize(keys.size());
  for (size_t i = 0; i < keys.size(); ++i) {
    regKeys[i] = regularPrefix_ + keys[i];
  }
  waitHelper_(regKeys, timeout);
}
void TCPStore::waitHelper_(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  // Set the socket timeout if there is a wait timeout
  if (timeout != kNoTimeout) {
    struct timeval timeoutTV = {.tv_sec = timeout.count() / 1000,
                                .tv_usec = (timeout.count() % 1000) * 1000};
    SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
        storeSocket_,
        SOL_SOCKET,
        SO_RCVTIMEO,
        reinterpret_cast<char*>(&timeoutTV),
        sizeof(timeoutTV)));
  }
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::WAIT);
  SizeType nkeys = keys.size();
  tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
  for (size_t i = 0; i < nkeys; i++) {
    tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1)));
  }
  auto waitResponse = tcputil::recvValue<WaitResponseType>(storeSocket_);
  if (waitResponse != WaitResponseType::STOP_WAITING) {
    throw std::runtime_error("Stop_waiting response is expected");
  }
}
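// Note on timeout behavior: SO_RCVTIMEO makes recv() on storeSocket_ fail
// (with EAGAIN/EWOULDBLOCK) if the daemon never sends STOP_WAITING within
// the timeout, which presumably surfaces as an error from recvValue above
// rather than as a dedicated timeout response.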

} // namespace c10d