mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Fix] add validation logic to TCPStore queries (#107607)
This PR fixes #106294. Because it lacks a request-validation mechanism, TCPStore in torch mistakenly treats nmap scan messages as valid query messages, which leads to DDP OOM. The simple solution enforces that the very first query from a client is a validation query carrying a predefined magic number. If the validation fails, the server terminates the connection. Pull Request resolved: https://github.com/pytorch/pytorch/pull/107607 Approved by: https://github.com/cbalioglu, https://github.com/XilunWu
This commit is contained in:
committed by
PyTorch MergeBot
parent
56e514aefb
commit
7c4e49ec80
@ -341,6 +341,9 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
|
||||
// TCP connection established
|
||||
C10D_DEBUG("TCP client connected to host {}:{}", addr_.host, addr_.port);
|
||||
|
||||
// client's first query for validation
|
||||
validate();
|
||||
|
||||
if (opts.waitWorkers) {
|
||||
waitForWorkers();
|
||||
}
|
||||
@ -386,6 +389,13 @@ void TCPStore::waitForWorkers() {
|
||||
}
|
||||
}
|
||||
|
||||
void TCPStore::validate(void) {
|
||||
const std::lock_guard<std::mutex> lock(activeOpLock_);
|
||||
detail::SendBuffer buffer(*client_, detail::QueryType::VALIDATE);
|
||||
buffer.appendValue<std::uint32_t>(c10d::detail::validationMagicNumber);
|
||||
buffer.flush();
|
||||
}
|
||||
|
||||
void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
|
||||
detail::timing_guard tguard(clientCounters_["set"]);
|
||||
const std::lock_guard<std::mutex> lock(activeOpLock_);
|
||||
|
@ -138,6 +138,8 @@ class TORCH_API TCPStore : public Store {
|
||||
private:
|
||||
int64_t incrementValueBy(const std::string& key, int64_t delta);
|
||||
|
||||
void validate(void);
|
||||
|
||||
std::vector<uint8_t> doGet(const std::string& key);
|
||||
|
||||
void doWait(
|
||||
|
@ -73,6 +73,7 @@ class TCPStoreMasterDaemon : public BackgroundThread {
|
||||
|
||||
// The master runs on a single thread so only
|
||||
// one handler can be executed at a time
|
||||
void validateHandler(int socket);
|
||||
void setHandler(int socket);
|
||||
void compareSetHandler(int socket);
|
||||
void addHandler(int socket);
|
||||
@ -85,6 +86,9 @@ class TCPStoreMasterDaemon : public BackgroundThread {
|
||||
void multiGetHandler(int socket);
|
||||
void multiSetHandler(int socket);
|
||||
void cancelWaitHandler(int socket);
|
||||
void addMiscellaneousSocket(int socket);
|
||||
void removeMiscellaneousSocket(int socket);
|
||||
bool isMiscellaneousSocket(int socket);
|
||||
|
||||
bool checkKeys(const std::vector<std::string>& keys) const;
|
||||
// Helper function to alert waiting workers, used in setHandler, getHandler
|
||||
@ -96,6 +100,8 @@ class TCPStoreMasterDaemon : public BackgroundThread {
|
||||
std::unordered_map<std::string, std::vector<int>> waitingSockets_;
|
||||
// From socket -> number of keys awaited
|
||||
std::unordered_map<int, size_t> keysAwaited_;
|
||||
// miscellaneous sockets
|
||||
std::unordered_set<int> miscellaneousSockets_;
|
||||
|
||||
Socket storeListenSocket_;
|
||||
std::vector<Socket> sockets_{};
|
||||
@ -250,7 +256,17 @@ void TCPStoreMasterDaemon::clearSocketWaitState(int socket) {
|
||||
void TCPStoreMasterDaemon::query(int socket) {
|
||||
QueryType qt;
|
||||
tcputil::recvBytes<QueryType>(socket, &qt, 1);
|
||||
if (qt == QueryType::SET) {
|
||||
|
||||
if (isMiscellaneousSocket(socket)) {
|
||||
removeMiscellaneousSocket(socket);
|
||||
if (qt == QueryType::VALIDATE) {
|
||||
validateHandler(socket);
|
||||
} else {
|
||||
// real miscellaneous client: the first msg is not VALIDATE
|
||||
TORCH_CHECK(
|
||||
false, "Miscellaneous client without VALIDATE query is detected");
|
||||
}
|
||||
} else if (qt == QueryType::SET) {
|
||||
setHandler(socket);
|
||||
|
||||
} else if (qt == QueryType::COMPARE_SET) {
|
||||
@ -307,6 +323,16 @@ void TCPStoreMasterDaemon::doSet(
|
||||
wakeupWaitingClients(key);
|
||||
}
|
||||
|
||||
void TCPStoreMasterDaemon::validateHandler(int socket) {
|
||||
uint32_t validateNumber;
|
||||
tcputil::recvBytes<uint32_t>(socket, &validateNumber, 1);
|
||||
if (validateNumber != detail::validationMagicNumber) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Miscellaneous client with incorrect VALIDATE query is detected");
|
||||
}
|
||||
}
|
||||
|
||||
void TCPStoreMasterDaemon::setHandler(int socket) {
|
||||
std::string key = tcputil::recvString(socket);
|
||||
std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
|
||||
@ -459,6 +485,23 @@ bool TCPStoreMasterDaemon::checkKeys(
|
||||
});
|
||||
}
|
||||
|
||||
void TCPStoreMasterDaemon::addMiscellaneousSocket(int socket) {
|
||||
if (miscellaneousSockets_.find(socket) == miscellaneousSockets_.end()) {
|
||||
miscellaneousSockets_.insert(socket);
|
||||
}
|
||||
}
|
||||
|
||||
void TCPStoreMasterDaemon::removeMiscellaneousSocket(int socket) {
|
||||
auto it = miscellaneousSockets_.find(socket);
|
||||
if (it != miscellaneousSockets_.end()) {
|
||||
miscellaneousSockets_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
bool TCPStoreMasterDaemon::isMiscellaneousSocket(int socket) {
|
||||
return miscellaneousSockets_.find(socket) != miscellaneousSockets_.end();
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
void TCPStoreMasterDaemon::run() {
|
||||
std::vector<struct pollfd> fds;
|
||||
@ -537,6 +580,8 @@ void TCPStoreMasterDaemon::run() {
|
||||
int rawSocket = socket.handle();
|
||||
sockets_.emplace_back(std::move(socket));
|
||||
tcputil::addPollfd(fds, rawSocket, POLLIN);
|
||||
// all clients are miscellaneous before getting its validation query
|
||||
addMiscellaneousSocket(rawSocket);
|
||||
}
|
||||
|
||||
// The pipe receives an event which tells us to shutdown the daemon
|
||||
|
@ -18,7 +18,11 @@
|
||||
namespace c10d {
|
||||
namespace detail {
|
||||
|
||||
// Magic number for client validation.
|
||||
static constexpr uint32_t validationMagicNumber = 0x3C85F7CE; // payload of the VALIDATE query
|
||||
|
||||
enum class QueryType : uint8_t {
|
||||
VALIDATE,
|
||||
SET,
|
||||
COMPARE_SET,
|
||||
GET,
|
||||
|
@ -578,6 +578,7 @@ class LibUVStoreDaemon : public BackgroundThread {
|
||||
void registerClient(c10::intrusive_ptr<UvHandle> client);
|
||||
void unregisterClient(c10::intrusive_ptr<UvHandle> client);
|
||||
void clearClientWaitState(c10::intrusive_ptr<UvHandle> client);
|
||||
bool isMiscellaneousClient(c10::intrusive_ptr<UvHandle> client);
|
||||
|
||||
uint16_t get_socket_port(uv_tcp_t* handle);
|
||||
void init(const TCPStoreOptions& opts);
|
||||
@ -598,6 +599,7 @@ class LibUVStoreDaemon : public BackgroundThread {
|
||||
// From socket -> number of keys awaited
|
||||
std::unordered_map<c10::intrusive_ptr<UvHandle>, size_t> keysAwaited_;
|
||||
std::unordered_set<c10::intrusive_ptr<UvHandle>> clients_;
|
||||
std::unordered_set<c10::intrusive_ptr<UvHandle>> miscellaneousClients_;
|
||||
int port_;
|
||||
|
||||
static LibUVStoreDaemon& from_uv(uv_handle_t* stream) {
|
||||
@ -635,67 +637,84 @@ class UvClient : public UvTcpSocket {
|
||||
uint8_t command = -1;
|
||||
if (!stream.read1(command))
|
||||
break;
|
||||
switch ((QueryType)command) {
|
||||
case QueryType::SET:
|
||||
if (!parse_set_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::COMPARE_SET:
|
||||
if (!parse_compare_set_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::GET:
|
||||
if (!parse_get_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::ADD:
|
||||
if (!parse_add_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::CHECK:
|
||||
if (!parse_check_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::WAIT:
|
||||
if (!parse_wait_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::GETNUMKEYS:
|
||||
if (!parse_getnumkeys_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::DELETE_KEY:
|
||||
if (!parse_delete_key_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::APPEND:
|
||||
if (!parse_append_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::MULTI_GET:
|
||||
if (!parse_multi_get_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::MULTI_SET:
|
||||
if (!parse_multi_set_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::CANCEL_WAIT:
|
||||
if (!parse_cancel_wait_command())
|
||||
return;
|
||||
break;
|
||||
default:
|
||||
C10D_DEBUG(
|
||||
"Client sent invalid command. client:{} command:{}",
|
||||
(void*)this,
|
||||
(int)command);
|
||||
close();
|
||||
if (store->isMiscellaneousClient(iptr())) {
|
||||
if ((QueryType)command != QueryType::VALIDATE)
|
||||
return;
|
||||
if (!parse_validate_command())
|
||||
return;
|
||||
} else {
|
||||
switch ((QueryType)command) {
|
||||
case QueryType::SET:
|
||||
if (!parse_set_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::COMPARE_SET:
|
||||
if (!parse_compare_set_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::GET:
|
||||
if (!parse_get_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::ADD:
|
||||
if (!parse_add_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::CHECK:
|
||||
if (!parse_check_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::WAIT:
|
||||
if (!parse_wait_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::GETNUMKEYS:
|
||||
if (!parse_getnumkeys_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::DELETE_KEY:
|
||||
if (!parse_delete_key_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::APPEND:
|
||||
if (!parse_append_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::MULTI_GET:
|
||||
if (!parse_multi_get_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::MULTI_SET:
|
||||
if (!parse_multi_set_command())
|
||||
return;
|
||||
break;
|
||||
case QueryType::CANCEL_WAIT:
|
||||
if (!parse_cancel_wait_command())
|
||||
return;
|
||||
break;
|
||||
default:
|
||||
C10D_DEBUG(
|
||||
"Client sent invalid command. client:{} command:{}",
|
||||
(void*)this,
|
||||
(int)command);
|
||||
close();
|
||||
return;
|
||||
}
|
||||
}
|
||||
stream.commit();
|
||||
}
|
||||
}
|
||||
|
||||
bool parse_validate_command() {
|
||||
uint32_t validateNumber;
|
||||
if (!stream.read_value(validateNumber))
|
||||
return false;
|
||||
|
||||
if (validateNumber != c10d::detail::validationMagicNumber)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool parse_set_command() {
|
||||
std::string key;
|
||||
if (!stream.read_key(key))
|
||||
@ -1061,12 +1080,25 @@ void LibUVStoreDaemon::stop() {
|
||||
}
|
||||
}
|
||||
|
||||
// Reports whether `client` is in miscellaneousClients_ and, if so, removes it.
// Because the check consumes the entry, a second call for the same client
// returns false unless the client has been inserted again in between.
bool LibUVStoreDaemon::isMiscellaneousClient(
    c10::intrusive_ptr<UvHandle> client) {
  // Erase-by-key returns the number of elements removed (0 or 1 for a set).
  return miscellaneousClients_.erase(client) > 0;
}
|
||||
|
||||
// Registers `client` with the daemon: it is added to clients_ and also to
// miscellaneousClients_, the set that isMiscellaneousClient() consults.
void LibUVStoreDaemon::registerClient(c10::intrusive_ptr<UvHandle> client) {
  // Track the connection itself.
  clients_.insert(client);

  // Every newly registered client starts out in the miscellaneous set.
  miscellaneousClients_.insert(client);
}
|
||||
|
||||
// Removes `client` from the daemon's bookkeeping: erases it from clients_
// and from miscellaneousClients_ (if still present there), then clears its
// wait state via clearClientWaitState().
void LibUVStoreDaemon::unregisterClient(c10::intrusive_ptr<UvHandle> client) {
  clients_.erase(client);

  // Erase-by-key is a no-op when the client is no longer in the set.
  miscellaneousClients_.erase(client);

  clearClientWaitState(client);
}
|
||||
|
||||
|
Reference in New Issue
Block a user