#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>

#ifdef USE_TENSORPIPE

#include <limits>
#include <tuple>
#include <utility>

#include <fmt/format.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated")
#include <tensorpipe/tensorpipe.h>
C10_DIAGNOSTIC_POP()

#include <torch/csrc/distributed/rpc/agent_utils.h>
#include <torch/csrc/distributed/rpc/tensorpipe_utils.h>
#include <torch/csrc/distributed/rpc/utils.h>

#include <c10/core/StreamGuard.h>
#include <c10/util/irange.h>

namespace torch::distributed::rpc {

namespace {

// An environment variable along the lines of GLOO_ and NCCL_SOCKET_IFNAME that
// allows the user to specify a device to bind to, instead of binding to the
// address that the hostname resolves to.
const std::string kSocketIfnameEnvVar = "TP_SOCKET_IFNAME";
const std::string kDefaultUvAddress = "127.0.0.1";

const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
const std::string kThreadPoolSize = "agent.thread_pool_size";
const std::string kNumIdleThreads = "agent.num_idle_threads";
const std::string kClientActiveCalls = "agent.client_active_calls";
const std::string kServerActiveCalls = "agent.server_active_calls";
const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";

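// Maps each tensor in an outgoing message to a destination device using the
// given device map. CPU tensors stay on CPU unless the map remaps them;
// non-CPU tensors must have an entry in the map or this throws. If no tensor
// ended up remapped the result is cleared, so an empty vector means a plain
// CPU-only transfer.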
std::vector<c10::Device> getDevicesForTensors(
    const std::vector<torch::Tensor>& tensors,
    const DeviceMap& deviceMap,
    const std::string& remoteName) {
  // If the deviceMap is overridden, use that instead.
  const auto errStr = c10::str(
      "TensorPipe RPC backend only supports CPU tensors by default, please "
      "move your tensors to CPU before sending them over RPC, or call "
      "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
      "configure device mapping. ",
      "Request device mapping is not available for destination ",
      remoteName);
  std::vector<c10::Device> devices;
  devices.reserve(tensors.size());
  bool hasMappedDevice = false;
  for (const auto& t : tensors) {
    if (t.device().is_cpu()) {
      const auto deviceIter = deviceMap.find(c10::kCPU);
      if (deviceIter == deviceMap.end()) {
        devices.emplace_back(c10::kCPU);
      } else {
        devices.emplace_back(deviceIter->second);
        hasMappedDevice = true;
      }
    } else {
      const auto deviceIter = deviceMap.find(t.device());
      TORCH_CHECK(
          deviceIter != deviceMap.end(),
          errStr,
          " for device ",
          t.device(),
          " but received a tensor on that device.");
      devices.push_back(deviceIter->second);
      hasMappedDevice = true;
    }
  }
  if (!hasMappedDevice) {
    devices.clear();
  }
  return devices;
}

std::vector<c10::Stream> getStreamsFromPoolForDevices(
    const std::vector<c10::Device>& devices) {
  if (devices.empty()) {
    return {};
  }
  c10::impl::VirtualGuardImpl impl(devices[0].type());
  std::vector<c10::Stream> streams;
  streams.reserve(devices.size());
  for (const c10::Device& device : devices) {
    TORCH_INTERNAL_ASSERT(device.type() == impl.type());
    streams.push_back(impl.getStreamFromGlobalPool(device));
  }
  return streams;
}

std::vector<c10::Stream> getCurrentStreamsForDevices(
    const std::vector<c10::Device>& devices) {
  if (devices.empty()) {
    return {};
  }
  c10::impl::VirtualGuardImpl impl(devices[0].type());
  std::vector<c10::Stream> streams;
  streams.reserve(devices.size());
  for (const c10::Device& device : devices) {
    TORCH_INTERNAL_ASSERT(device.type() == impl.type());
    streams.push_back(impl.getStream(device));
  }
  return streams;
}

std::vector<c10::Device> getDevicesOfTensors(
    const std::vector<torch::Tensor>& tensors) {
  std::optional<c10::impl::VirtualGuardImpl> impl;
  size_t deviceCount = 0;
  std::vector<bool> indexBitset;
  for (const torch::Tensor& tensor : tensors) {
    if (!tensor.is_cpu()) {
      c10::Device device = tensor.device();
      if (!impl.has_value()) {
        impl.emplace(device.type());
        indexBitset.resize(impl->deviceCount());
      }
      TORCH_INTERNAL_ASSERT(device.type() == impl->type());
      TORCH_INTERNAL_ASSERT(device.has_index());
      if (!indexBitset[device.index()]) {
        deviceCount++;
        indexBitset[device.index()] = true;
      }
    }
  }
  std::vector<c10::Device> devices;
  devices.reserve(deviceCount);
  for (const auto idx : c10::irange(indexBitset.size())) {
    if (indexBitset[idx]) {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      devices.emplace_back(impl->type(), static_cast<c10::DeviceIndex>(idx));
    }
  }
  return devices;
}

void makeStreamsWaitOnOthers(
    const std::vector<c10::Stream>& consumers,
    const std::vector<c10::Stream>& producers) {
  for (const c10::Stream& producer : producers) {
    const c10::Stream& consumer =
        getStreamForDevice(consumers, producer.device());
    c10::Event event(producer.device_type());
    event.record(producer);
    event.block(consumer);
  }
}

} // namespace

C10_DEFINE_REGISTRY_WITHOUT_WARNING(
    TensorPipeTransportRegistry,
    TransportRegistration)

C10_DEFINE_REGISTRY_WITHOUT_WARNING(
    TensorPipeChannelRegistry,
    ChannelRegistration)

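// Returns the address the UV-based transports should bind to: the address of
// the interface named by TP_SOCKET_IFNAME if that variable is set, otherwise
// the address the hostname resolves to, falling back to 127.0.0.1 if the
// lookup fails. The result is computed once and cached for the process.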
const std::string& TensorPipeAgent::guessAddress() {
  static const std::string uvAddress = []() {
    auto ifnameEnv = c10::utils::get_env(kSocketIfnameEnvVar.c_str());
    if (ifnameEnv.has_value()) {
      auto [error, result] =
          tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv.value());
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for interface "
                     << ifnameEnv.value() << " (" << error.what()
                     << "), defaulting to " << kDefaultUvAddress;
        return kDefaultUvAddress;
      }
      return result;
    }
    auto [error, result] = tensorpipe::transport::uv::lookupAddrForHostname();
    if (error) {
      LOG(WARNING) << "Failed to look up the IP address for the hostname ("
                   << error.what() << "), defaulting to " << kDefaultUvAddress;
      return kDefaultUvAddress;
    }
    return result;
  }();
  return uvAddress;
}

namespace {

std::unique_ptr<TransportRegistration> makeUvTransport() {
  auto context = tensorpipe::transport::uv::create();
  std::string address = TensorPipeAgent::guessAddress();
  return std::make_unique<TransportRegistration>(TransportRegistration{
      std::move(context), kUvTransportPriority, std::move(address)});
}

// The UV transport is implemented using standard TCP connections. It leverages
// libuv (https://github.com/libuv/libuv) in order to be cross-platform.
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport)

#if TENSORPIPE_HAS_SHM_TRANSPORT

std::unique_ptr<TransportRegistration> makeShmTransport() {
  auto context = tensorpipe::transport::shm::create();
  return std::make_unique<TransportRegistration>(
      TransportRegistration{std::move(context), kShmTransportPriority, ""});
}

// The SHM transport implements connections using ringbuffers residing in
// anonymous shared memory (plus UNIX domain sockets to bootstrap the
// connection and exchange file descriptors). It is Linux-only due to some
// advanced features (O_TMPFILE, eventfd, ...).
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, shm, makeShmTransport)

#endif // TENSORPIPE_HAS_SHM_TRANSPORT

#if TENSORPIPE_HAS_IBV_TRANSPORT

std::unique_ptr<TransportRegistration> makeIbvTransport() {
  auto context = tensorpipe::transport::ibv::create();
  std::string address = TensorPipeAgent::guessAddress();
  return std::make_unique<TransportRegistration>(TransportRegistration{
      std::move(context), kIbvTransportPriority, std::move(address)});
}

// The IBV transport sends data across using an InfiniBand queue pair, locally
// copying data to and from a staging buffer (registered with libibverbs) and
// issuing a RDMA write for transferring data across machines (plus a send for
// acknowledging it). It bootstraps using a standard TCP connection to exchange
// setup information. It is Linux-only.
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, ibv, makeIbvTransport)

#endif // TENSORPIPE_HAS_IBV_TRANSPORT

std::unique_ptr<ChannelRegistration> makeBasicChannel() {
  auto context = tensorpipe::channel::basic::create();
  return std::make_unique<ChannelRegistration>(
      ChannelRegistration{std::move(context), kBasicChannelPriority});
}

// The basic channel is just a straightforward adapter wrapper that allows any
// transport to be used as a channel.
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, basic, makeBasicChannel)

#if TENSORPIPE_HAS_CMA_CHANNEL

std::unique_ptr<ChannelRegistration> makeCmaChannel() {
  auto context = tensorpipe::channel::cma::create();
  return std::make_unique<ChannelRegistration>(
      ChannelRegistration{std::move(context), kCmaChannelPriority});
}

// The CMA channel uses the Linux cross-memory attach syscalls (process_vm_readv
// and _writev), which allow one process to access the private memory of another
// process (as long as they belong to the same user and other security
// constraints are satisfied). It does, more or less, what GDB does when it's
// attached to a running process.
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, cma, makeCmaChannel)

#endif // TENSORPIPE_HAS_CMA_CHANNEL

constexpr static int kNumUvThreads = 16;

std::unique_ptr<ChannelRegistration> makeMultiplexedUvChannel() {
  std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
  contexts.reserve(kNumUvThreads);
  std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
  listeners.reserve(kNumUvThreads);
  for ([[maybe_unused]] const auto laneIdx : c10::irange(kNumUvThreads)) {
    auto context = tensorpipe::transport::uv::create();
    const std::string& address = TensorPipeAgent::guessAddress();
    contexts.push_back(std::move(context));
    listeners.push_back(contexts.back()->listen(address));
  }
  auto context = tensorpipe::channel::mpt::create(
      std::move(contexts), std::move(listeners));
  return std::make_unique<ChannelRegistration>(
      ChannelRegistration{std::move(context), kMultiplexedUvChannelPriority});
}

// The multiplexed UV channel encapsulates multiple UV transports (each with its
// own event loop thread). Each channel will, in turn, contain multiple UV
// connections, one for each of those contexts. When sending a tensor, its data
// is split in equal chunks and each chunk is sent on a different connection
// and thus driven by a different thread. This is needed to reach very high
// bandwidths.
C10_REGISTER_CREATOR(
    TensorPipeChannelRegistry,
    mpt_uv,
    makeMultiplexedUvChannel)

} // namespace

////////////////////////// MetricsTracker /////////////////////////////////

TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(
    uint64_t currentSum,
    uint64_t currentCount)
    : currentSum_(currentSum), currentCount_(currentCount) {}

void TensorPipeAgent::TimeSeriesMetricsTracker::addData(uint64_t dataPoint) {
  currentSum_ += dataPoint;
  ++currentCount_;
}

float TensorPipeAgent::TimeSeriesMetricsTracker::computeAverage() const {
  return currentCount_ == 0 ? 0
                            : static_cast<float>(
                                  static_cast<double>(currentSum_) /
                                  static_cast<double>(currentCount_));
}

//////////////////////// TensorpipeRpcAgent /////////////////////////////////

void TensorPipeAgent::removeFromTimeoutMap(uint64_t messageId) {
  // Remove entry from timeoutMap_.
  {
    std::unique_lock<std::mutex> lock(timeoutMapMutex_);
    auto it = messageIdToTimeout_.find(messageId);
    if (it == messageIdToTimeout_.end()) {
      // Already removed from the map by pollTimeoutRpcs(), no need to
      // process further.
      return;
    }

    auto& expirationTime = it->second;

    auto& timedOutFuturesVector = timeoutMap_[expirationTime];
    for (auto it = timedOutFuturesVector.begin();
         it != timedOutFuturesVector.end();
         it++) {
      if (it->messageId == messageId) {
        it = timedOutFuturesVector.erase(it);
        break;
      }
    }

    if (timedOutFuturesVector.empty()) {
      timeoutMap_.erase(expirationTime);
    }

    // Remove from messageId to timeout map as well.
    messageIdToTimeout_.erase(messageId);
  }
}

void TensorPipeAgent::prepareNames(bool isStaticGroup) {
  std::unordered_map<std::string, worker_id_t> nameToId;
  if (isStaticGroup) {
    nameToId = collectNames(
        rankToNameStore_, workerInfo_.id_, workerInfo_.name_, worldSize_);
  } else {
    nameToId = collectCurrentNames(
        rankToNameStore_, workerInfo_.id_, workerInfo_.name_);
  }

  for (const auto& entry : nameToId) {
    const auto& workerName = entry.first;
    const auto& workerId = entry.second;
    workerIdToInfo_.emplace(workerId, WorkerInfo(workerName, workerId));
    workerNameToInfo_.emplace(workerName, WorkerInfo(workerName, workerId));
  }
}

void TensorPipeAgent::checkAndSetStaticGroup(
    const c10::intrusive_ptr<::c10d::Store>& store) {
  std::string isStaticGroupKey("rpcIsStaticGroup");

  std::string isStaticGroupStr = isStaticGroup_ ? "true" : "false";
  std::vector<uint8_t> isStaticGroupVec(
      reinterpret_cast<const uint8_t*>(isStaticGroupStr.c_str()),
      reinterpret_cast<const uint8_t*>(isStaticGroupStr.c_str()) +
          isStaticGroupStr.length());
  std::vector<uint8_t> returnedVec;
  returnedVec = store->compareSet(
      isStaticGroupKey, std::vector<uint8_t>(), isStaticGroupVec);
  std::string returnedVal = std::string(returnedVec.begin(), returnedVec.end());
  // In both cases, the returned value should be the value of isStaticGroupStr,
  // otherwise there is a discrepancy with initialization among one of the
  // members
  TORCH_CHECK(
      returnedVal == isStaticGroupStr,
      fmt::format(
          "RPC group mixes statically and dynamically initialized members "
          "which is not supported. Static group property is initialized as {} "
          "and is trying to be set as {}",
          isStaticGroup_,
          returnedVal));
}

TensorPipeAgent::TensorPipeAgent(
    const c10::intrusive_ptr<::c10d::Store>& store,
    std::string selfName,
    worker_id_t selfId,
    std::optional<int> worldSize,
    TensorPipeRpcBackendOptions opts,
    std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
    std::vector<c10::Device> devices,
    std::unique_ptr<RequestCallback> cb)
    : RpcAgent(
          WorkerInfo(std::move(selfName), selfId),
          std::move(cb),
          std::chrono::milliseconds(
              static_cast<long>(opts.rpcTimeoutSeconds * kSecToMsConversion))),
      isStaticGroup_(worldSize.has_value()),
      store_(store),
      opts_(std::move(opts)),
      reverseDeviceMaps_(std::move(reverseDeviceMaps)),
      devices_(std::move(devices)),
      threadPool_(opts_.numWorkerThreads),
      context_(std::make_shared<tensorpipe::Context>(
          tensorpipe::ContextOptions().name(workerInfo_.name_))),
      rankToNameStore_("names", store),
      nameToAddressStore_("addrs", store),
      shutdownStore_("shutdown", store) {
  if (isStaticGroup_) {
    worldSize_ = worldSize.value();
  }

  // check the static group attribute against store
  checkAndSetStaticGroup(store);

  // collect worker names
  prepareNames(isStaticGroup_);

  // Initialize the time-series metrics tracking map
  timeSeriesMetrics_.emplace(kGilAverageWaitTime, TimeSeriesMetricsTracker());
}

TensorPipeAgent::~TensorPipeAgent() {
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is being destroyed";
  shutdown();
}

void TensorPipeAgent::startImpl() {
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is starting";

  std::vector<std::string> addresses;
  int64_t lowestPriority = std::numeric_limits<int64_t>::max();
  std::string lowestPriorityTransport;

  // Register transports
  for (auto& key : TensorPipeTransportRegistry()->Keys()) {
    int64_t priority = -1;
    if (opts_.transports.has_value()) {
      auto iter =
          std::find(opts_.transports->begin(), opts_.transports->end(), key);
      if (iter == opts_.transports->end()) {
        continue;
      }
      // Assign priorities in reverse order of occurrence in the vector, so that
      // a transport that comes before another receives a higher priority.
      priority = static_cast<std::ptrdiff_t>(opts_.transports->size()) - 1 -
          (iter - opts_.transports->begin());
    }
    std::unique_ptr<TransportRegistration> reg =
        TensorPipeTransportRegistry()->Create(key);
    if (!reg->transport->isViable()) {
      continue;
    }
    if (priority == -1) {
      priority = reg->priority;
    }
    if (priority < lowestPriority) {
      lowestPriority = priority;
      lowestPriorityTransport = key;
    }
    addresses.push_back(c10::str(key, "://", reg->address));
    context_->registerTransport(priority, key, reg->transport);
  }

  // Register channels
  for (auto& key : TensorPipeChannelRegistry()->Keys()) {
    int64_t priority = -1;
    if (opts_.channels.has_value()) {
      auto iter =
          std::find(opts_.channels->begin(), opts_.channels->end(), key);
      if (iter == opts_.channels->end()) {
        continue;
      }
      // Assign priorities in reverse order of occurrence in the vector, so
      // that a channel that comes before another receives a higher priority.
      priority = static_cast<std::ptrdiff_t>(opts_.channels->size()) - 1 -
          (iter - opts_.channels->begin());
    }
    std::unique_ptr<ChannelRegistration> reg =
        TensorPipeChannelRegistry()->Create(key);
    if (!reg->channel->isViable()) {
      continue;
    }
    if (priority == -1) {
      priority = reg->priority;
    }
    context_->registerChannel(priority, key, reg->channel);
  }

  listener_ = context_->listen(addresses);

  // Store our own url.
  const auto address = listener_->url(lowestPriorityTransport);
  nameToAddressStore_.set(workerInfo_.name_, address);

  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is using address "
          << address;

  for (const auto& p : workerNameToInfo_) {
    const auto& name = p.first;
    auto nodeAddrData = nameToAddressStore_.get(name);
    auto nodeAddrStr = std::string(
        reinterpret_cast<const char*>(nodeAddrData.data()),
        nodeAddrData.size());
    workerNameToURL_.insert({name, nodeAddrStr});
  }

  // Start the Timeout Thread
  timeoutThread_ = std::thread(&TensorPipeAgent::pollTimeoutRpcs, this);

  listener_->accept([this](
                        const tensorpipe::Error& error,
                        std::shared_ptr<tensorpipe::Pipe> pipe) {
    onListenerAccepted(error, pipe);
  });
}

void TensorPipeAgent::onListenerAccepted(
    const tensorpipe::Error& error,
    std::shared_ptr<tensorpipe::Pipe>& pipe) {
  if (error) {
    if (error.isOfType<tensorpipe::ListenerClosedError>() &&
        !rpcAgentRunning_.load()) {
      // This is expected.
    } else {
      LOG(WARNING) << "RPC agent for " << workerInfo_.name_
                   << " encountered error when accepting incoming pipe: "
                   << error.what();
    }
    return;
  }

  // Accept the next connection request
  listener_->accept([this](
                        const tensorpipe::Error& error,
                        std::shared_ptr<tensorpipe::Pipe> pipe) {
    onListenerAccepted(error, pipe);
  });

  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " accepted incoming pipe from " << pipe->getRemoteName();

  // Arm for server read
  respond(pipe);
}

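// pipeRead and pipeWrite wrap TensorPipe's two-phase I/O: a read first
// receives a descriptor, allocates buffers on streams drawn from the pool for
// this agent's devices, then reads the payload and deserializes it into a
// Message; a write serializes the Message and keeps its buffers alive until
// the write callback fires. Both callbacks are invoked by TensorPipe's event
// loops, which is why heavier work is deferred to the thread pool elsewhere.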
void TensorPipeAgent::pipeRead(
    const std::shared_ptr<tensorpipe::Pipe>& pipe,
    std::function<void(
        const tensorpipe::Error&,
        c10::intrusive_ptr<Message>,
        std::vector<c10::Stream>)> fn) noexcept {
  pipe->readDescriptor([this, fn{std::move(fn)}, pipe](
                           const tensorpipe::Error& error,
                           tensorpipe::Descriptor tpDescriptor) mutable {
    if (error) {
      fn(error, c10::intrusive_ptr<Message>(), {});
      return;
    }

    std::vector<c10::Stream> streams;
    {
      GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
      streams = getStreamsFromPoolForDevices(devices_);
    }
    auto [tpAllocation, tpBuffers] = tensorpipeAllocate(tpDescriptor, streams);

    pipe->read(
        std::move(tpAllocation),
        [tpDescriptor{std::move(tpDescriptor)},
         tpBuffers{
             std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
         fn{std::move(fn)},
         streams{std::move(streams)}](const tensorpipe::Error& error) mutable {
          if (error) {
            fn(error, c10::intrusive_ptr<Message>(), {});
            return;
          }

          // FIXME This does some unpickling, which could be a bit expensive:
          // perhaps it would be best to perform it inside the worker threads?
          c10::intrusive_ptr<Message> rpcMessage =
              tensorpipeDeserialize(tpDescriptor, std::move(*tpBuffers));

          fn(error, std::move(rpcMessage), std::move(streams));
        });
  });
}

void TensorPipeAgent::pipeWrite(
    const std::shared_ptr<tensorpipe::Pipe>& pipe,
    const c10::intrusive_ptr<Message>& rpcMessage,
    std::vector<c10::Device>&& devices,
    std::vector<c10::Stream> streams,
    std::function<void(const tensorpipe::Error&)> fn) noexcept {
  auto [tpMessage, tpBuffers] =
      tensorpipeSerialize(rpcMessage, std::move(devices), streams);

  pipe->write(
      std::move(tpMessage),
      [tpBuffers{
           std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers))},
       fn{std::move(fn)},
       streams{std::move(streams)}](const tensorpipe::Error& error) {
        fn(error);
      });
}

void TensorPipeAgent::sendCompletedResponseMessage(
    std::shared_ptr<tensorpipe::Pipe>& pipe,
    JitFuture& futureResponseMessage,
    uint64_t messageId,
    std::vector<c10::Stream> streams) {
  if (!rpcAgentRunning_.load()) {
    LOG(WARNING) << "RPC agent for " << workerInfo_.name_
                 << " won't send response to request #" << messageId << " to "
                 << pipe->getRemoteName() << ", as the agent is shutting down";
    return;
  }

  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " is sending response to request #" << messageId << " to "
          << pipe->getRemoteName();

  if (!futureResponseMessage.hasError()) {
    c10::intrusive_ptr<Message> responseMessage =
        futureResponseMessage.value().toCustomClass<Message>();
    responseMessage->setId(static_cast<int64_t>(messageId));

    std::vector<c10::Device> devices;
    try {
      devices = getDevicesForRemote(pipe->getRemoteName(), *responseMessage);
    } catch (const std::exception& e) {
      responseMessage =
          createExceptionResponse(e.what(), static_cast<int64_t>(messageId));
    }

    for (const auto& tensor : responseMessage->tensors()) {
      const auto device = tensor.device();
      if (!device.is_cpu()) {
        GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
        if (std::find(devices_.begin(), devices_.end(), device) ==
            devices_.end()) {
          std::ostringstream oss;
          std::copy(
              devices_.begin(),
              devices_.end(),
              std::ostream_iterator<c10::Device>(oss, ", "));
          responseMessage = createExceptionResponse(
              c10::str(
                  "RPC detected that a user-function output tensor on device ",
                  device,
                  ". This device is not one of the input tensor devices: ",
                  oss.str(),
                  "which is not yet supported. Please file a feature request "
                  "issue in PyTorch GitHub repo."),
              static_cast<int64_t>(messageId));
          break;
        }
      }
    }

    pipeWrite(
        pipe,
        responseMessage,
        std::move(devices),
        std::move(streams),
        [this, pipe, messageId](const tensorpipe::Error& error) {
          if (error) {
            LOG(WARNING)
                << "RPC agent for " << workerInfo_.name_
                << " encountered error when sending response to request #"
                << messageId << " to " << pipe->getRemoteName() << ": "
                << error.what();
            return;
          }

          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " done sending response to request #" << messageId
                  << " to " << pipe->getRemoteName();
        });
  } else {
    pipeWrite(
        pipe,
        createExceptionResponse(
            futureResponseMessage.tryRetrieveErrorMessage(),
            static_cast<int64_t>(messageId)),
        /* devices */ {},
        std::move(streams),
        [this, pipe, messageId](const tensorpipe::Error& error) {
          if (error) {
            LOG(WARNING)
                << "RPC agent for " << workerInfo_.name_
                << " encountered error when sending response to request #"
                << messageId << " to " << pipe->getRemoteName() << ": "
                << error.what();
            return;
          }

          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " done sending response to request #" << messageId
                  << " to " << pipe->getRemoteName();
        });
  }
}

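// Server-side request loop: each accepted pipe is armed with a read; when a
// request arrives the pipe is immediately re-armed for the next one, and the
// request itself is handed to the thread pool, which runs the request
// callback and, once the resulting future completes, sends the response back
// over the same pipe.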
void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
  pipeRead(
      pipe,
      [this, pipe](
          const tensorpipe::Error& error,
          c10::intrusive_ptr<Message> requestMessage,
          std::vector<c10::Stream> streams) mutable {
        if (error) {
          if (shuttingDown_) {
            // This is expected.
          } else {
            LOG(WARNING)
                << "RPC agent for " << workerInfo_.name_
                << " encountered error when reading incoming request from "
                << pipe->getRemoteName() << ": " << error.what();
          }
          return;
        }

        // Arm for next read
        respond(pipe);

        uint64_t messageId = requestMessage->id();
        increaseCallCount(serverActiveCalls_);

        VLOG(1) << "RPC agent for " << workerInfo_.name_
                << " received request #" << messageId << " from "
                << pipe->getRemoteName();

        // Defer user RPC UDF run to thread pool
        threadPool_.run([this,
                         pipe,
                         messageId,
                         requestMessage{std::move(requestMessage)},
                         streams{std::move(streams)}]() mutable {
          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " is running request #" << messageId << " from "
                  << pipe->getRemoteName() << " in thread pool";

          c10::intrusive_ptr<JitFuture> futureResponseMessage;
          try {
            // Instead of creating a MultiStreamGuard here, the ctx is passed
            // to the callback and the MultiStreamGuard is created there,
            // because subsequent processing can switch threads due to 1)
            // waiting for RRef arguments to become ready 2) async_execution.
            // Besides, the `ctx` also needs to be propagated to
            // `process***Call` methods to synchronize CUDA streams there
            // to make sure that we fetch the correct value from `to_here()`
            // call.
            futureResponseMessage =
                cb_->operator()(*requestMessage, std::move(streams));
          } catch (const std::exception& /* unused */) {
            futureResponseMessage =
                c10::make_intrusive<JitFuture>(at::AnyClassType::get());
            futureResponseMessage->setError(std::current_exception());
          }

          increaseCallCount(serverActiveAsyncCalls_);
          futureResponseMessage->addCallback(
              [this, pipe, messageId](
                  JitFuture& futureResponseMessage) mutable {
                decreaseCallCount(serverActiveCalls_);
                decreaseCallCount(serverActiveAsyncCalls_);
                auto streams = getCurrentStreamsForDevices(
                    futureResponseMessage.devices());
                sendCompletedResponseMessage(
                    pipe, futureResponseMessage, messageId, std::move(streams));
              });

          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " done running request #" << messageId << " from "
                  << pipe->getRemoteName() << " in thread pool";
        });
      });
}

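// Client-side send path: look up (or lazily create) the pipe to the
// destination worker, register the pending response future, resolve the
// device map for the outgoing tensors, optionally arm a timeout entry, then
// write the request and arm a read for the matching response. The returned
// JitFuture is completed (or marked with an error) from that read callback.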
c10::intrusive_ptr<JitFuture> TensorPipeAgent::send(
    const WorkerInfo& toWorkerInfo,
    c10::intrusive_ptr<Message> requestMessage,
    const float rpcTimeoutSeconds,
    const DeviceMap& deviceMap) {
  TORCH_CHECK(
      requestMessage->isRequest(),
      "TensorPipeAgent::send(..) is only for sending requests.");

  if (!rpcAgentRunning_.load()) {
    auto err = c10::str(
        "Node ",
        RpcAgent::getWorkerInfo().id_,
        " tried to send() a message of type ",
        requestMessage->type(),
        " but RPC is no longer running on this node.");
    TORCH_CHECK(false, err);
  }

  const auto& url = findWorkerURL(toWorkerInfo);

  decltype(connectedPipes_)::iterator it;
  {
    std::unique_lock<std::mutex> lock(connectedPipesMutex_);

    // See if we already have a connection to this address or not
    it = connectedPipes_.find(toWorkerInfo.id_);
    if (it == connectedPipes_.end()) {
      // An instance of ClientPipe cannot be copied or moved as it contains a
      // mutex, and to force in-place construction in GCC 5 we need piecewise
      // construction in order to work around an issue.
      it = connectedPipes_
               .emplace(
                   std::piecewise_construct,
                   std::forward_as_tuple(toWorkerInfo.id_),
                   std::forward_as_tuple(context_->connect(
                       url,
                       tensorpipe::PipeOptions().remoteName(
                           toWorkerInfo.name_))))
               .first;
    }
  }
  ClientPipe& clientPipe = it->second;

  std::shared_ptr<torch::distributed::rpc::TensorPipeAgent::AtomicJitFuture>
      futureResponseMessage;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    futureResponseMessage = std::make_shared<AtomicJitFuture>(devices_);
  }
  uint64_t messageId = nextMessageID_++;
  requestMessage->setId(static_cast<int64_t>(messageId));

  {
    std::unique_lock<std::mutex> lock(clientPipe.mutex_);
    clientPipe.pendingResponseMessage_[messageId] = futureResponseMessage;
  }

  // Get devices for tensors in the request message. This can throw if device
  // maps are not configured properly for this request.
  std::vector<c10::Device> devices;
  if (deviceMap.empty()) {
    devices =
        getDevicesForRemote(clientPipe.pipe_->getRemoteName(), *requestMessage);
  } else {
    // If deviceMap is specified, use that instead.
    devices = getDevicesForTensors(
        requestMessage->tensors(),
        deviceMap,
        clientPipe.pipe_->getRemoteName());
  }

  futureResponseMessage->jitFuture->addCallback(
      [this](JitFuture& /* unused */) {
        TORCH_INTERNAL_ASSERT(
            this->threadPool_.inThreadPool(),
            "Future marked complete from outside the thread pool");
      });

  increaseCallCount(clientActiveCalls_);
  // Use the default RPC timeout if no timeout is specified for this send call
  auto timeout = rpcTimeoutSeconds == kUnsetRpcTimeout
      ? getRpcTimeout()
      : std::chrono::milliseconds(
            static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));

  // We only add to the timeoutMap_ if the timeout is not 0. Per our
  // documentation, a user-provided timeout of 0 indicates the RPC should never
  // expire (infinite timeout), so there is no need to track it in the
  // timeoutMap_.
  steady_clock_time_point expirationTime;
  if (timeout.count() != 0) {
    // Compute the expiration time for this message based on the timeout
    expirationTime = computeRpcMessageExpiryTime(timeout);

    // Add the Future to the right vector in the timeoutMap_
    {
      std::unique_lock<std::mutex> lock(timeoutMapMutex_);
      auto& timeoutFuturesVector = timeoutMap_[expirationTime];
      messageIdToTimeout_.emplace(messageId, expirationTime);
      timeoutFuturesVector.emplace_back(
          messageId, futureResponseMessage, timeout);
    }
    timeoutThreadCV_.notify_one();
  }

  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
          << messageId << " to " << clientPipe.pipe_->getRemoteName();

  std::vector<c10::Stream> streams;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    streams = getStreamsFromPoolForDevices(devices_);
  }
  makeStreamsWaitOnOthers(
      streams,
      getCurrentStreamsForDevices(
          getDevicesOfTensors(requestMessage->tensors())));
  pipeWrite(
      clientPipe.pipe_,
      requestMessage,
      std::move(devices),
      std::move(streams),
      [this, &clientPipe, messageId](const tensorpipe::Error& error) mutable {
        if (error) {
          if (error.isOfType<tensorpipe::PipeClosedError>() &&
              !rpcAgentRunning_.load()) {
            // This is expected.
          } else {
            LOG(WARNING) << "RPC agent for " << workerInfo_.name_
                         << " encountered error when sending outgoing request #"
                         << messageId << " to "
                         << clientPipe.pipe_->getRemoteName() << ": "
                         << error.what();
          }
          handleClientError(clientPipe, error);
          return;
        }

        VLOG(1) << "RPC agent for " << workerInfo_.name_ << " sent request #"
                << messageId << " to " << clientPipe.pipe_->getRemoteName();

        pipeRead(
            clientPipe.pipe_,
            [this, &clientPipe](
                const tensorpipe::Error& error,
                c10::intrusive_ptr<Message> responseMessage,
                std::vector<c10::Stream> streams) {
              if (error) {
                if (error.isOfType<tensorpipe::PipeClosedError>() &&
                    !rpcAgentRunning_.load()) {
                  // This is expected.
                } else {
                  LOG(WARNING)
                      << "RPC agent for " << workerInfo_.name_
                      << " encountered error when reading incoming response from "
                      << clientPipe.pipe_->getRemoteName() << ": "
                      << error.what();
                }
                handleClientError(clientPipe, error);
                return;
              }

              // Identify future response message by message ID
              uint64_t messageId = responseMessage->id();

              VLOG(1) << "RPC agent for " << workerInfo_.name_
                      << " received response #" << messageId << " from "
                      << clientPipe.pipe_->getRemoteName();

              std::shared_ptr<AtomicJitFuture> futureResponseMessage;
              {
                std::lock_guard<std::mutex> lock(clientPipe.mutex_);
                // A read error will lead all following callbacks to be
                // invoked with error, and shouldn't reach here.
                TORCH_INTERNAL_ASSERT(
                    !clientPipe.inError_, "Shouldn't be in error state");
                auto it = clientPipe.pendingResponseMessage_.find(messageId);
                TORCH_INTERNAL_ASSERT(
                    it != clientPipe.pendingResponseMessage_.end(),
                    "message ID ",
                    messageId,
                    " is not recognized");
                futureResponseMessage = std::move(it->second);
                clientPipe.pendingResponseMessage_.erase(it);
              }

              // Remove entry from timeoutMap_.
              removeFromTimeoutMap(messageId);

              if (responseMessage->type() == MessageType::EXCEPTION) {
                markFutureWithError(
                    std::move(futureResponseMessage),
                    std::string(
                        responseMessage->payload().begin(),
                        responseMessage->payload().end()));
              } else {
                markFutureAsComplete(
                    std::move(futureResponseMessage),
                    std::move(responseMessage),
                    std::move(streams));
              }
            });
      });

  return futureResponseMessage->jitFuture;
}

void TensorPipeAgent::handleClientError(
    ClientPipe& clientPipe,
    const tensorpipe::Error& error) {
  // When an error occurs on a pipe all pending operations will be aborted and
  // all callbacks invoked with error, hence we immediately flush all future
  // messages belonging to this pipe.
  decltype(clientPipe.pendingResponseMessage_) pendingMsgs;
  {
    std::lock_guard<std::mutex> lock(clientPipe.mutex_);
    std::swap(clientPipe.pendingResponseMessage_, pendingMsgs);
    clientPipe.inError_ = true;
  }
  std::string errorMsg = error.what();
  for (auto& p : pendingMsgs) {
    markFutureWithError(std::move(p.second), errorMsg);

    // Remove entry from timeoutMap_.
    removeFromTimeoutMap(p.first);
  }
}

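// Background thread that enforces RPC timeouts: it sleeps until the earliest
// expiration recorded in timeoutMap_ (or until woken by send() or shutdown),
// then marks every future that expired at that time with a TIMEOUT error,
// doing so outside the lock to avoid lock-order inversions with callbacks.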
void TensorPipeAgent::pollTimeoutRpcs() {
  while (rpcAgentRunning_.load()) {
    std::unique_lock<std::mutex> lock(timeoutMapMutex_);

    // We sleep until the earliest expiring RPC in the timeoutMap_. We must
    // also ensure that we sleep while the map is empty, and we exit sleeping
    // if the RPC Agent has been shutdown.
    for (;;) {
      if (!rpcAgentRunning_.load()) {
        return;
      }

      if (!timeoutMap_.empty()) {
        steady_clock_time_point earliestTimeout = timeoutMap_.begin()->first;
        if (std::chrono::steady_clock::now() >= earliestTimeout) {
          break;
        }
        timeoutThreadCV_.wait_until(lock, earliestTimeout);
      } else {
        timeoutThreadCV_.wait(lock);
      }
    }

    // Move all these futures to a separate vector so we can process them
    // outside the lock.
    std::vector<TimeoutMessageMetadata> timedOutFutures =
        std::move(timeoutMap_.begin()->second);

    // We can safely remove this key from the timeoutMap_ since all these
    // futures will be processed.
    timeoutMap_.erase(timeoutMap_.begin());

    for (auto& timeoutMetadata : timedOutFutures) {
      // Remove from messageIdToTimeout map.
      messageIdToTimeout_.erase(timeoutMetadata.messageId);
    }
    lock.unlock();

    // Set an error on futures added to the timedOutFutures vector. We do this
    // outside the lock to prevent potential lock-order-inversions by callbacks
    // triggered by the setError call.
    for (auto& timeoutMetadata : timedOutFutures) {
      std::string errorMsg =
          fmt::format(kRpcTimeoutErrorStr, timeoutMetadata.timeout.count());
      auto err = makeRPCError(errorMsg, RPCErrorType::TIMEOUT);
      markFutureWithError(
          std::move(timeoutMetadata.responseFuture), std::move(err));
    }
  }
}

void TensorPipeAgent::leaveGroup() {
  std::unique_lock<std::mutex> lock(callCountMutex_);
  // local worker ActiveCallCount is 0 at this point and we will shutdown
  // (any future calls will be dropped)
  callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });

  // Remove this agent's WorkerInfo from store
  removeCurrentName(rankToNameStore_, workerInfo_.id_, workerInfo_.name_);

  // Set internal variable to be used during destructor
  shuttingDown_ = true;
}

// TODO: Remove join()
void TensorPipeAgent::join(bool shutdown, float /* unused */) {
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is joining";
  if (!isStaticGroup_) {
    leaveGroup();
    return;
  }

  // This method behaves like a barrier, as it can only return once all workers
  // have no more requests pending, including "nested" requests (triggered from
  // within the remote code of another call) and "follow-up" requests (triggered
  // from the callback of a future).
  while (true) {
    {
      std::unique_lock<std::mutex> lock(callCountMutex_);
      // It is enough to wait for there to be no more active client calls, since
      // each server call corresponds to a client call for some other worker.
      callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });

      // We'd like to immediately proceed with the allreduce, but it's a call
      // that may block for some time, as it waits for other workers to also
      // complete all their active client calls. While we call allreduce we must
      // hold the mutex, or else the count we send to other workers may get
      // stale (e.g., if some nested call happens in the meantime). But we can't
      // hold the lock for an indeterminately long time, as that would block
      // other operations (e.g., send). Thus we must release the lock and only
      // re-acquire it when all workers are ready to proceed with the allreduce.
      // We perform this synchronization using a barrier.
    }
    VLOG(1) << "RPC agent for " << workerInfo_.name_
            << " completed all client calls and is entering a barrier";
    syncCallCount(shutdownStore_, worldSize_);
    {
      std::unique_lock<std::mutex> lock(callCountMutex_);
      // At this point, the count may have become non-zero again. We can't wait
      // for those calls to complete as other workers are waiting for us in the
      // allreduce and we would block them. Thus we send our count even if it is
      // non-zero and if anyone (be it us or another worker) has a non-zero
      // count we'll just do another round.
      VLOG(1) << "RPC agent for " << workerInfo_.name_
              << " exited the barrier and found " << clientActiveCalls_
              << " active client calls";
      int totalClientActiveCalls =
          syncCallCount(shutdownStore_, worldSize_, clientActiveCalls_);
      VLOG(1) << "RPC agent for " << workerInfo_.name_
              << " completed sync call counts and got a total of "
              << totalClientActiveCalls
              << " active client calls across all workers";
      if (totalClientActiveCalls == 0) {
        if (shutdown) {
          shuttingDown_ = true;
          syncCallCount(shutdownStore_, worldSize_);
        }
        break;
      }
    }
  }
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done joining";
}

void TensorPipeAgent::shutdownImpl() {
  // FIXME Isn't it too verbose for a library to print logs in normal operation?
  LOG(INFO) << "RPC agent for " << workerInfo_.name_ << " is shutting down";

  // Join the Timeout Thread
  timeoutThreadCV_.notify_one();
  if (timeoutThread_.joinable()) {
    timeoutThread_.join();
  }
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for timeout thread to join";

  // This will close all the pipes and listeners, invoke all callbacks with
  // errors, turn down the I/O event loops and wait for everything to terminate.
  context_->join();
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for TensorPipe context to join";

  // NOTE: We need to call waitWorkComplete in the end after we have shutdown
  // all listeners for Tensorpipe. This is to drain any already accepted work
  // in the ThreadPool. If this is done before we shutdown the listeners,
  // additional work could be added after this call and before we shutdown
  // listeners. This work would continue executing in the threadpool and might
  // cause issues during shutdown of the system.
  threadPool_.waitWorkComplete();
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for thread pool to complete work";
}

const WorkerInfo& TensorPipeAgent::getWorkerInfo(
    const std::string& workerName) const {
  std::unordered_map<std::string, WorkerInfo>::const_iterator it;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    it = workerNameToInfo_.find(workerName);
  }
  TORCH_CHECK(
      it != workerNameToInfo_.end(),
      fmt::format(
          "name:{},rank:{} could not find destination name {}",
          workerInfo_.name_,
          workerInfo_.id_,
          workerName));
  return it->second;
}

const WorkerInfo& TensorPipeAgent::getWorkerInfo(worker_id_t workerId) const {
  std::unordered_map<worker_id_t, WorkerInfo>::const_iterator it;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    it = workerIdToInfo_.find(workerId);
  }
  TORCH_CHECK(
      it != workerIdToInfo_.end(),
      fmt::format(
          "name:{},rank:{} could not find destination id {}",
          workerInfo_.name_,
          workerInfo_.id_,
          workerId));
  return it->second;
}

std::vector<WorkerInfo> TensorPipeAgent::getWorkerInfos() const {
  std::vector<WorkerInfo> workerInfos;
  workerInfos.reserve(workerNameToInfo_.size());
  for (auto& item : workerNameToInfo_) {
    workerInfos.emplace_back(item.second);
  }
  return workerInfos;
}

const std::string& TensorPipeAgent::findWorkerURL(
    const WorkerInfo& worker) const {
  std::unordered_map<std::string, std::string>::const_iterator it;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    it = workerNameToURL_.find(worker.name_);
  }
  TORCH_CHECK(
      it != workerNameToURL_.end(),
      fmt::format(
          "name:{},rank:{} could not find destination url for name {}",
          workerInfo_.name_,
          workerInfo_.id_,
          worker.name_));
  return it->second;
}

void TensorPipeAgent::updateGroupMembership(
    const WorkerInfo& workerInfo,
    const std::vector<c10::Device>& devices,
    const std::unordered_map<std::string, DeviceMap>& reverseDeviceMaps,
    bool isJoin) {
  std::string name = workerInfo.name_;
  worker_id_t id = workerInfo.id_;
  // Rank with workerInfo is joining the group, update internal mappings
  if (isJoin) {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    workerIdToInfo_.emplace(id, workerInfo);
    workerNameToInfo_.emplace(name, workerInfo);

    // TODO: we should get nodeAddrStr in the joining process, then pass in as
    // an argument rather than getting from store each time
    auto nodeAddrData = nameToAddressStore_.get(name);
    auto nodeAddrStr = std::string(
        reinterpret_cast<const char*>(nodeAddrData.data()),
        nodeAddrData.size());
    workerNameToURL_.insert({name, nodeAddrStr});

    for (const auto& it : reverseDeviceMaps) {
      if (reverseDeviceMaps_.find(it.first) == reverseDeviceMaps_.end()) {
        reverseDeviceMaps_[it.first] = it.second;
      }
    }
    // TODO: clean up mutex for devices_ usage
    // Add devices that have not been added yet
    for (const auto& it : devices) {
      if (std::find(devices_.begin(), devices_.end(), it) == devices_.end()) {
        devices_.push_back(it);
      }
    }
  } else {
    workerIdToInfo_.erase(id);
    workerNameToInfo_.erase(name);
    workerNameToURL_.erase(name);

    // remove reverse device maps that are no longer used
    for (auto it = reverseDeviceMaps_.begin();
         it != reverseDeviceMaps_.end();) {
      if (reverseDeviceMaps.find(it->first) == reverseDeviceMaps.end()) {
        it = reverseDeviceMaps_.erase(it);
      } else {
        it++;
      }
    }

    // remove devices that are no longer used
    for (auto it = devices_.begin(); it != devices_.end();) {
      if (std::find(devices.begin(), devices.end(), *it) == devices.end()) {
        it = devices_.erase(it);
      } else {
        it++;
      }
    }
  }
}

std::unordered_map<std::string, std::string> TensorPipeAgent::getMetrics() {
  std::unordered_map<std::string, std::string> metrics;
  metrics[kThreadPoolSize] = std::to_string(threadPool_.size());
  metrics[kNumIdleThreads] = std::to_string(threadPool_.numAvailable());
  {
    std::unique_lock<std::mutex> lock(callCountMutex_);
    metrics[kClientActiveCalls] = std::to_string(clientActiveCalls_);
    metrics[kServerActiveCalls] = std::to_string(serverActiveCalls_);
    metrics[kServerActiveAsyncCalls] = std::to_string(serverActiveAsyncCalls_);
  }
  if (isGILProfilingEnabled()) {
    {
      std::unique_lock<std::mutex> lock(metricsMutex_);
      // Include the averages for each time series metric. This is just the GIL
      // Wait Time for now.
      auto averageGilWaitTime =
          timeSeriesMetrics_[kGilAverageWaitTime].computeAverage();
      lock.unlock();
      metrics[kGilAverageWaitTime] = std::to_string(averageGilWaitTime);
    }
  }

  return metrics;
}

void TensorPipeAgent::addGilWaitTime(
    const std::chrono::microseconds gilWaitTime) {
  std::lock_guard<std::mutex> lock(metricsMutex_);
  timeSeriesMetrics_[kGilAverageWaitTime].addData(gilWaitTime.count());
}

TensorPipeAgent::NetworkDataDict TensorPipeAgent::getNetworkData() {
  std::lock_guard<std::mutex> lock(networkDataMutex_);
  return networkData_;
}

NetworkSourceInfo TensorPipeAgent::getNetworkSourceInfo() {
  NetworkSourceInfo info = {
      RpcAgent::getWorkerInfo().id_,
      nameToAddressStore_.get(RpcAgent::getWorkerInfo().name_)};

  return info;
}

void TensorPipeAgent::trackNetworkData(
    uint64_t requestSize,
    uint64_t responseSize,
    const std::string& destWorkerName) {
  std::lock_guard<std::mutex> lock(networkDataMutex_);
  networkData_[destWorkerName].numCalls++;
  networkData_[destWorkerName].totalSentBytes += requestSize;
  networkData_[destWorkerName].totalRecvBytes += responseSize;
}

void TensorPipeAgent::trackNetworkError(
    uint64_t requestSize,
    const std::string& destWorkerName) {
  std::lock_guard<std::mutex> lock(networkDataMutex_);
  networkData_[destWorkerName].numCalls++;
  networkData_[destWorkerName].totalSentBytes += requestSize;
  networkData_[destWorkerName].totalErrors++;
}

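// Call-count bookkeeping shared by send(), respond() and join(): the counters
// are guarded by callCountMutex_ and every change wakes callCountCV_, which
// join() and leaveGroup() wait on until no active client calls remain.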
void TensorPipeAgent::increaseCallCount(int32_t& count) {
  {
    std::unique_lock<std::mutex> lock(callCountMutex_);
    ++count;
  }
  callCountCV_.notify_all();
}

void TensorPipeAgent::decreaseCallCount(int32_t& count) {
  {
    std::unique_lock<std::mutex> lock(callCountMutex_);
    --count;
  }
  callCountCV_.notify_all();
}

void TensorPipeAgent::markFutureAsComplete(
    std::shared_ptr<AtomicJitFuture> atomicFuture,
    c10::intrusive_ptr<Message> message,
    std::vector<c10::Stream> streams) {
  if (!atomicFuture->isComplete.test_and_set()) {
    // Completing the future will run its callbacks, which could execute
    // arbitrary user code. To prevent blocking or stalling the TensorPipe event
    // loops, we defer this to a worker thread.
    threadPool_.run([this,
                     atomicFuture{std::move(atomicFuture)},
                     message{std::move(message)},
                     streams{std::move(streams)}]() mutable {
      c10::MultiStreamGuard guard(streams);
      std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages =
          message->getStorages();
      atomicFuture->jitFuture->markCompleted(
          std::move(message), std::move(storages));
      // The future's callbacks may schedule further RPCs, increasing the count.
      // Thus we must decrease it after completing the future, otherwise it may
      // briefly dip to zero and trick join into thinking all work is done.
      decreaseCallCount(clientActiveCalls_);
    });
  }
}

void TensorPipeAgent::markFutureWithError(
    std::shared_ptr<AtomicJitFuture> atomicFuture,
    std::string errorMsg) {
  if (!atomicFuture->isComplete.test_and_set()) {
    // Completing the future will run its callbacks, which could execute
    // arbitrary user code. To prevent blocking or stalling the TensorPipe event
    // loops, we defer this to a worker thread.
    threadPool_.run([this,
                     atomicFuture{std::move(atomicFuture)},
                     errorMsg{std::move(errorMsg)}]() mutable {
      atomicFuture->jitFuture->setError(
          std::make_exception_ptr(std::runtime_error(errorMsg)));
      // The future's callbacks may schedule further RPCs, increasing the count.
      // Thus we must decrease it after completing the future, otherwise it may
      // briefly dip to zero and trick join into thinking all work is done.
      decreaseCallCount(clientActiveCalls_);
    });
  }
}

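// Resolves the device map to use for a message exchanged with `remoteName`:
// requests use the locally configured opts_.deviceMaps, responses use the
// reverse maps received from peers. If no map is configured for that worker,
// every tensor in the message must be on CPU.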
std::vector<c10::Device> TensorPipeAgent::getDevicesForRemote(
    const std::string& remoteName,
    const Message& message) const {
  std::unordered_map<std::string, DeviceMap> deviceMaps;
  {
    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    deviceMaps = message.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_;
  }

  const auto errStr = c10::str(
      "TensorPipe RPC backend only supports CPU tensors by default, please "
      "move your tensors to CPU before sending them over RPC, or call "
      "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
      "configure device mapping. ",
      message.isRequest() ? "Request" : "Response",
      " device mapping is not available for destination ",
      remoteName);

  const auto& iter = deviceMaps.find(remoteName);
  if (iter == deviceMaps.end()) {
    for (const auto& t : message.tensors()) {
      TORCH_CHECK(
          t.device().is_cpu(),
          errStr,
          ", but found tensor on device: ",
          t.device());
    }
    return {};
  } else {
    return getDevicesForTensors(message.tensors(), iter->second, errStr);
  }
}

DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo& dst) const {
  auto it = opts_.deviceMaps.find(dst.name_);
  if (it == opts_.deviceMaps.end()) {
    return {};
  }
  return it->second;
}

const c10::intrusive_ptr<::c10d::Store> TensorPipeAgent::getStore() const {
  return store_;
}

TensorPipeRpcBackendOptions TensorPipeAgent::getBackendOptions() const {
  return opts_;
}

const std::vector<c10::Device>& TensorPipeAgent::getDevices() const {
  GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
  return devices_;
}

size_t TensorPipeAgent::timeoutMapSize() {
  std::unique_lock<std::mutex> lock(timeoutMapMutex_);
  return timeoutMap_.size();
}

size_t TensorPipeAgent::numPendingResponses() {
  std::unique_lock<std::mutex> lock(callCountMutex_);
  return clientActiveCalls_;
}

size_t TensorPipeAgent::messageIdToTimeoutMapSize() {
  std::unique_lock<std::mutex> lock(timeoutMapMutex_);
  return messageIdToTimeout_.size();
}

} // namespace torch::distributed::rpc

#endif // USE_TENSORPIPE