Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 13:44:15 +08:00
While looking at the TCPStore code again, I found it confusing that we still keep the deprecated TCPStore constructor in C++ even though we no longer expose it in Python via pybind. I checked both internal and external callers: all C++ use cases (aside from a unit test fixed in this PR) have already moved to the options-based constructor. So let's remove this legacy constructor to avoid confusion. Differential Revision: [D62653634](https://our.internmc.facebook.com/intern/diff/D62653634) Pull Request resolved: https://github.com/pytorch/pytorch/pull/136004 Approved by: https://github.com/Skylion007, https://github.com/XilunWu
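For reference, with the legacy constructor gone, C++ callers go through TCPStoreOptions, matching the `TCPStore(std::string host, const TCPStoreOptions&)` signature in this file. A minimal sketch, assuming the intrusive-pointer ownership that c10d stores use; the field values are illustrative, not prescriptive:

// Illustrative sketch only: constructing a TCPStore via the options struct
// now that the legacy multi-parameter constructor is removed.
#include <torch/csrc/distributed/c10d/TCPStore.hpp>

c10::intrusive_ptr<c10d::TCPStore> makeServerStore() {
  c10d::TCPStoreOptions opts{};
  opts.port = 29500;       // illustrative rendezvous port
  opts.isServer = true;    // this process hosts the server daemon
  opts.numWorkers = 4;     // illustrative world size
  opts.waitWorkers = true; // block until all workers have checked in
  return c10::make_intrusive<c10d::TCPStore>("localhost", opts);
}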
#include <c10/util/WaitCounter.h>
#include <c10/util/irange.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <torch/csrc/distributed/c10d/Backoff.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/TCPStoreBackend.hpp>
#include <torch/csrc/distributed/c10d/logging.h>

#include <fcntl.h>
#include <chrono>
#include <fstream>
#include <random>
#include <thread>
#include <unordered_map>
#include <utility>

#ifdef _WIN32
#include <io.h>
#include <winsock2.h>
#else
#include <poll.h>
#include <unistd.h>
#endif

#ifdef _WIN32
#include <torch/csrc/distributed/c10d/WinSockUtils.hpp>
#else
#include <torch/csrc/distributed/c10d/UnixSockUtils.hpp>
#endif

#include <torch/csrc/distributed/c10d/socket.h>

namespace c10d {
namespace detail {

// Manages the lifecycle of a server daemon.
class TCPServer {
 public:
  static std::shared_ptr<TCPServer> start(const TCPStoreOptions& opts);

  std::uint16_t port() const noexcept {
    return port_;
  }

  explicit TCPServer(
      std::uint16_t port,
      std::unique_ptr<BackgroundThread>&& daemon)
      : port_{port}, daemon_{std::move(daemon)} {}

  std::string repr() const {
    return fmt::format("TCPServer(port={})", port_);
  }

 private:
  std::uint16_t port_;
  std::unique_ptr<BackgroundThread> daemon_;

  // We store weak references to all TCPServers for which the caller requested
  // multi-tenancy.
  static std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
      cachedServers_;

  static std::mutex cache_mutex_;
};

std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
    TCPServer::cachedServers_{};

std::mutex TCPServer::cache_mutex_{};

std::shared_ptr<TCPServer> TCPServer::start(const TCPStoreOptions& opts) {
  auto startCore = [&opts]() {
    auto daemon = opts.useLibUV ? create_libuv_tcpstore_backend(opts)
                                : create_tcpstore_backend(opts);
    daemon->start();
    return std::make_shared<TCPServer>(daemon->port(), std::move(daemon));
  };

  std::shared_ptr<TCPServer> server{};

  if (opts.multiTenant) {
    std::lock_guard<std::mutex> guard{cache_mutex_};

    // If the caller is okay with a multi-tenant store, first check if we
    // already have a TCPServer running on the specified port.
    if (opts.port > 0) {
      auto pos = cachedServers_.find(opts.port);
      if (pos != cachedServers_.end()) {
        server = pos->second.lock();
        if (server != nullptr) {
          return server;
        }

        // Looks like the TCPStore has been disposed, make sure that we release
        // the control block.
        cachedServers_.erase(pos);
      }
    }

    server = startCore();

    cachedServers_.emplace(server->port(), server);
  } else {
    server = startCore();
  }

  return server;
}

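// Wraps the client end of a store connection: raw sends, typed receives, and
// receive-timeout control over a single socket.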
class TCPClient {
 public:
  static std::unique_ptr<TCPClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts,
      std::shared_ptr<Backoff> backoff);

  void sendRaw(uint8_t* data, size_t length) {
    try {
      tcputil::sendBytes(socket_.handle(), data, length);
    } catch (const std::exception& e) {
      C10D_WARNING("sendBytes failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  std::vector<std::uint8_t> receiveBits() {
    try {
      return tcputil::recvVector<std::uint8_t>(socket_.handle());
    } catch (const std::exception& e) {
      C10D_WARNING("recvVector failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  template <typename T>
  T receiveValue() {
    try {
      return tcputil::recvValue<T>(socket_.handle());
    } catch (const std::exception& e) {
      C10D_WARNING("recvValue failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  template <typename T>
  bool receiveValueWithTimeout(T& t, std::chrono::milliseconds timeout) {
    if (!socket_.waitForInput(timeout))
      return false;
    t = tcputil::recvValue<T>(socket_.handle());
    return true;
  }

  void setTimeout(std::chrono::milliseconds value);

  explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {}

  std::string repr() const {
    return fmt::format("TCPClient({})", socket_.repr());
  }

 private:
  Socket socket_;
};

std::unique_ptr<TCPClient> TCPClient::connect(
    const SocketAddress& addr,
    const TCPStoreOptions& opts,
    std::shared_ptr<Backoff> backoff) {
  Socket socket = Socket::connect(
      addr.host,
      addr.port,
      SocketOptions{}
          .connect_timeout(opts.timeout)
          .connect_backoff(std::move(backoff)));

  return std::make_unique<TCPClient>(std::move(socket));
}

void TCPClient::setTimeout(std::chrono::milliseconds value) {
  if (value == std::chrono::milliseconds::zero()) {
    return;
  }

#ifdef _WIN32
  struct timeval timeoutTV = {
      static_cast<long>(value.count() / 1000),
      static_cast<long>((value.count() % 1000) * 1000)};
#else
  struct timeval timeoutTV = {
      .tv_sec = value.count() / 1000,
      .tv_usec = static_cast<suseconds_t>((value.count() % 1000) * 1000),
  };
#endif
  SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
      socket_.handle(),
      SOL_SOCKET,
      SO_RCVTIMEO,
      reinterpret_cast<char*>(&timeoutTV),
      sizeof(timeoutTV)));
}

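// Accumulates a single request (command byte followed by its arguments) and
// streams it to the server, flushing early whenever the pending bytes reach a
// typical Ethernet MTU payload.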
class SendBuffer {
  // ethernet mtu 1500 - 40 (ip v6 header) - 20 (tcp header)
  const size_t FLUSH_WATERMARK = 1440;
  std::vector<uint8_t> buffer;
  detail::TCPClient& client;

  void maybeFlush() {
    if (buffer.size() >= FLUSH_WATERMARK) {
      flush();
    }
  }

 public:
  SendBuffer(detail::TCPClient& client, detail::QueryType cmd)
      : client(client) {
    buffer.reserve(32); // enough for most commands
    buffer.push_back((uint8_t)cmd);
  }

  void appendString(const std::string& str) {
    appendValue<uint64_t>(str.size());
    buffer.insert(buffer.end(), str.begin(), str.end());
    maybeFlush();
  }

  void appendBytes(const std::vector<uint8_t>& vec) {
    appendValue<uint64_t>(vec.size());
    buffer.insert(buffer.end(), vec.begin(), vec.end());
    maybeFlush();
  }

  template <typename T>
  void appendValue(T value) {
    uint8_t* begin = (uint8_t*)&value;
    buffer.insert(buffer.end(), begin, begin + sizeof(T));
    maybeFlush();
  }

  void flush() {
    if (!buffer.empty()) {
      client.sendRaw(buffer.data(), buffer.size());
      buffer.clear();
    }
  }
};

} // namespace detail

using detail::Socket;

// TCPStore class methods

// Although we still allow multi-params in the ctor in Python, that behavior is
// removed from cpp and we construct the opts implicitly for users in the
// pybind of TCPStore.
TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
    : Store{opts.timeout},
      addr_{std::move(host)},
      numWorkers_{opts.numWorkers},
      usingLibUv_{opts.useLibUV} {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__init);

  if (opts.useLibUV) {
    TORCH_CHECK(
        ::c10d::detail::is_libuv_tcpstore_backend_available(),
        "use_libuv was requested but PyTorch was built without libuv support");

    if (opts.masterListenFd.has_value()) {
      // TODO(xilunwu): support this init method after testing
      constexpr auto* msg =
          "The libuv TCPStore backend does not support initialization with a listen fd. "
          "Please switch to the legacy TCPStore by setting environment variable USE_LIBUV "
          "to \"0\".";
      C10D_ERROR(msg);
      C10_THROW_ERROR(NotImplementedError, msg);
      return;
    }
  }

  Socket::initialize();

  if (opts.isServer) {
    server_ = detail::TCPServer::start(opts);
    // server successfully started
    C10D_DEBUG("The server has started on port = {}.", server_->port());

    std::ifstream maxconnFile("/proc/sys/net/core/somaxconn");
    if (maxconnFile.good() && numWorkers_.has_value()) {
      try {
        std::string str(
            (std::istreambuf_iterator<char>(maxconnFile)),
            std::istreambuf_iterator<char>());
        std::size_t somaxconn = std::stoll(str);
        if (somaxconn < *numWorkers_) {
          C10D_WARNING(
              "Starting store with {} workers but somaxconn is {}. "
              "This might cause instability during bootstrap, consider increasing it.",
              *numWorkers_,
              somaxconn);
        }
      } catch (std::logic_error& e) {
        C10D_INFO("failed to parse somaxconn proc file due to {}", e.what());
      }
    }

    addr_.port = server_->port();
  } else {
    addr_.port = opts.port;
  }

  // Try connecting several times -- if the server listen backlog is full it
  // may fail on the first send in validate.
  auto deadline = std::chrono::steady_clock::now() + opts.timeout;
  auto backoff = std::make_shared<ExponentialBackoffWithJitter>();

  auto retry = 0;
  do {
    try {
      client_ = detail::TCPClient::connect(addr_, opts, backoff);
      // TCP connection established
      C10D_DEBUG("TCP client connected to host {}:{}", addr_.host, addr_.port);

      // client's first query for validation
      validate();

      // ping to verify network connectivity
      ping();

      // success
      break;
    } catch (const c10::DistNetworkError& ex) {
      if (deadline < std::chrono::steady_clock::now()) {
        C10D_ERROR(
            "TCP client failed to connect/validate to host {}:{} - timed out (try={}, timeout={}ms): {}",
            addr_.host,
            addr_.port,
            retry,
            opts.timeout.count(),
            ex.what());
        throw;
      }

      auto delayDuration = backoff->nextBackoff();

      C10D_WARNING(
          "TCP client failed to connect/validate to host {}:{} - retrying (try={}, timeout={}ms, delay={}ms): {}",
          addr_.host,
          addr_.port,
          retry,
          opts.timeout.count(),
          delayDuration.count(),
          ex.what());

      std::this_thread::sleep_for(delayDuration);
      retry += 1;
    }
  } while (true);

  if (opts.waitWorkers) {
    waitForWorkers();
  }
}

TCPStore::~TCPStore() = default;

void TCPStore::waitForWorkers() {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__waitForWorkers);
  if (numWorkers_ == std::nullopt) {
    return;
  }

  incrementValueBy(initKey_, 1);

  // Let server block until all workers have completed, this ensures that
  // the server daemon thread is always running until the very end
  if (server_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      // TODO: Any chance to make this cleaner?
      std::vector<uint8_t> value = doGet(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= static_cast<int>(*numWorkers_)) {
        break;
      }
      const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
          std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        C10_THROW_ERROR(
            DistStoreError,
            fmt::format(
                "Timed out after {} seconds waiting for clients. {}/{} clients joined.",
                elapsed.count(),
                numWorkersCompleted,
                *numWorkers_));
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}

void TCPStore::validate() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::VALIDATE);
  buffer.appendValue<std::uint32_t>(c10d::detail::validationMagicNumber);
  buffer.flush();
}

void TCPStore::ping() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::PING);

  uint32_t nonce = getpid();
  buffer.appendValue<std::uint32_t>(nonce);
  buffer.flush();

  uint32_t returnedNonce = client_->receiveValue<std::uint32_t>();
  TORCH_INTERNAL_ASSERT(
      nonce == returnedNonce, "Ping failed, invalid nonce returned");
}

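// Sends the key and the payload of a SET request as two separate flushes with
// a deliberate one-second pause in between, presumably so tests can exercise
// the server's handling of a partially received message.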
void TCPStore::_splitSet(
    const std::string& key,
    const std::vector<uint8_t>& data) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.flush();
  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  buffer.appendBytes(data);
  buffer.flush();
}

void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__set);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(data);
  buffer.flush();
}

std::vector<uint8_t> TCPStore::compareSet(
    const std::string& key,
    const std::vector<uint8_t>& expectedValue,
    const std::vector<uint8_t>& desiredValue) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__compareSet);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::COMPARE_SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(expectedValue);
  buffer.appendBytes(desiredValue);
  buffer.flush();

  return client_->receiveBits();
}

std::vector<uint8_t> TCPStore::get(const std::string& key) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__get);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  return doGet(keyPrefix_ + key);
}

std::vector<uint8_t> TCPStore::doGet(const std::string& key) {
  doWait(key, timeout_);
  detail::SendBuffer buffer(*client_, detail::QueryType::GET);
  buffer.appendString(key);
  buffer.flush();

  return client_->receiveBits();
}

int64_t TCPStore::add(const std::string& key, int64_t value) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__add);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  return incrementValueBy(keyPrefix_ + key, value);
}

bool TCPStore::deleteKey(const std::string& key) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__delete);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::DELETE_KEY);
  buffer.appendString(keyPrefix_ + key);
  buffer.flush();

  auto numDeleted = client_->receiveValue<std::int64_t>();
  return numDeleted == 1;
}

int64_t TCPStore::incrementValueBy(const std::string& key, int64_t delta) {
  detail::SendBuffer buff(*client_, detail::QueryType::ADD);
  buff.appendString(key);
  buff.appendValue<std::int64_t>(delta);
  buff.flush();

  return client_->receiveValue<std::int64_t>();
}

int64_t TCPStore::getNumKeys() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::GETNUMKEYS);
  buffer.flush();

  return client_->receiveValue<std::int64_t>();
}

bool TCPStore::check(const std::vector<std::string>& keys) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__check);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::CHECK);
  buffer.appendValue(keys.size());

  for (const std::string& key : keys) {
    buffer.appendString(keyPrefix_ + key);
  }
  buffer.flush();

  auto response = client_->receiveValue<detail::CheckResponseType>();
  if (response == detail::CheckResponseType::READY) {
    return true;
  }
  if (response == detail::CheckResponseType::NOT_READY) {
    return false;
  }
  TORCH_CHECK(false, "ready or not_ready response expected");
}

void TCPStore::wait(const std::vector<std::string>& keys) {
  wait(keys, timeout_);
}

void TCPStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__wait);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  std::vector<std::string> prefixedKeys{};
  prefixedKeys.reserve(keys.size());
  for (const std::string& key : keys) {
    prefixedKeys.emplace_back(keyPrefix_ + key);
  }

  doWait(prefixedKeys, timeout);
}

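// Issues a WAIT for the given keys and blocks until the server replies. On a
// client-side timeout it sends CANCEL_WAIT and then must absorb a possible
// in-flight STOP_WAITING response before the WAIT_CANCELED acknowledgement,
// since the server may respond before it sees the cancellation.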
void TCPStore::doWait(
    c10::ArrayRef<std::string> keys,
    std::chrono::milliseconds timeout) {
  {
    detail::SendBuffer buffer(*client_, detail::QueryType::WAIT);
    buffer.appendValue(keys.size());
    for (const std::string& key : keys) {
      buffer.appendString(key);
    }
    buffer.flush();
  }

  detail::WaitResponseType response;
  if (client_->receiveValueWithTimeout<detail::WaitResponseType>(
          response, timeout)) {
    if (response != detail::WaitResponseType::STOP_WAITING) {
      TORCH_CHECK(false, "Stop_waiting response is expected");
    }
    return;
  }
  // this is the cancel wait timeout, once here we expect the server to respond
  // in a timely fashion
  {
    detail::SendBuffer buffer(*client_, detail::QueryType::CANCEL_WAIT);
    buffer.flush();
  }

  response = client_->receiveValue<detail::WaitResponseType>();
  // this can happen if the server responds before we cancel, just ignore it
  if (response != detail::WaitResponseType::WAIT_CANCELED) {
    if (response != detail::WaitResponseType::STOP_WAITING) {
      TORCH_CHECK(false, "Stop_waiting response is expected");
    }

    response = client_->receiveValue<detail::WaitResponseType>(); // ignore
    if (response != detail::WaitResponseType::WAIT_CANCELED) {
      TORCH_CHECK(false, "wait_canceled response is expected");
    }
  }
  C10_THROW_ERROR(
      DistStoreError,
      fmt::format(
          "wait timeout after {}ms, keys: {}",
          timeout.count(),
          fmt::join(keys, ", ")));
}

void TCPStore::append(
    const std::string& key,
    const std::vector<uint8_t>& data) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__append);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::APPEND);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(data);
  buffer.flush();
}

std::vector<std::vector<uint8_t>> TCPStore::multiGet(
    const std::vector<std::string>& keys) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiGet);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  std::vector<std::string> prefixedKeys;
  prefixedKeys.reserve(keys.size());
  for (const std::string& key : keys) {
    prefixedKeys.emplace_back(keyPrefix_ + key);
  }
  doWait(prefixedKeys, timeout_);

  detail::SendBuffer buffer(*client_, detail::QueryType::MULTI_GET);
  buffer.appendValue(keys.size());
  for (auto& key : prefixedKeys) {
    buffer.appendString(key);
  }
  buffer.flush();

  std::vector<std::vector<uint8_t>> result;
  result.reserve(keys.size());
  for (size_t i = 0; i < keys.size(); ++i) {
    result.emplace_back(client_->receiveBits());
  }
  return result;
}

void TCPStore::multiSet(
    const std::vector<std::string>& keys,
    const std::vector<std::vector<uint8_t>>& values) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiSet);
  TORCH_CHECK(
      keys.size() == values.size(),
      "multiSet keys and values vectors must be of same size");
  const std::lock_guard<std::mutex> lock(activeOpLock_);

  detail::SendBuffer buffer(*client_, detail::QueryType::MULTI_SET);
  buffer.appendValue<std::int64_t>(keys.size());
  for (auto i : c10::irange(keys.size())) {
    buffer.appendString(keyPrefix_ + keys[i]);
    buffer.appendBytes(values[i]);
  }
  buffer.flush();
}

bool TCPStore::hasExtendedApi() const {
  return true;
}

std::string TCPStore::repr() const {
  auto clientRepr = client_ ? client_->repr() : "<nullptr>";
  auto serverRepr = server_ ? server_->repr() : "<nullptr>";
  return fmt::format("TCPStore(client={}, server={})", clientRepr, serverRepr);
}

} // namespace c10d