pytorch/torch/csrc/distributed/rpc/tensorpipe_agent.cpp
Yuanyuan Chen e1e8491b31 [1/N] Change C-style casts to static_cast or reinterpret_cast (#165750)
This series of changes converts C-style casts into their C++ alternatives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165750
Approved by: https://github.com/Skylion007
2025-10-20 04:36:19 +00:00


#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
#ifdef USE_TENSORPIPE
#include <limits>
#include <tuple>
#include <utility>
#include <fmt/format.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated")
#include <tensorpipe/tensorpipe.h>
C10_DIAGNOSTIC_POP()
#include <torch/csrc/distributed/rpc/agent_utils.h>
#include <torch/csrc/distributed/rpc/tensorpipe_utils.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/irange.h>
namespace torch::distributed::rpc {
namespace {
// An environment variable along the lines of GLOO_ and NCCL_SOCKET_IFNAME that
// allows the user to specify a device to bind to, instead of binding to the
// address that the hostname resolves to.
const std::string kSocketIfnameEnvVar = "TP_SOCKET_IFNAME";
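// Example usage (hypothetical interface name):
//   TP_SOCKET_IFNAME=eth0 python my_rpc_worker.py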
const std::string kDefaultUvAddress = "127.0.0.1";
const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
const std::string kThreadPoolSize = "agent.thread_pool_size";
const std::string kNumIdleThreads = "agent.num_idle_threads";
const std::string kClientActiveCalls = "agent.client_active_calls";
const std::string kServerActiveCalls = "agent.server_active_calls";
const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
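// Maps each tensor in `tensors` to a destination device using `deviceMap`.
// CPU tensors with no explicit mapping stay on CPU; non-CPU tensors must have
// an entry in the map, otherwise a TORCH_CHECK failure is raised. If no
// mapping was applied to any tensor, the returned vector is empty, signalling
// a plain CPU-only transfer. (The map is typically populated on the Python
// side via `TensorPipeRpcBackendOptions.set_device_map`, as the error message
// below suggests.)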
std::vector<c10::Device> getDevicesForTensors(
const std::vector<torch::Tensor>& tensors,
const DeviceMap& deviceMap,
const std::string& remoteName) {
// If the deviceMap is overridden, use that instead.
const auto errStr = c10::str(
"TensorPipe RPC backend only supports CPU tensors by default, please "
"move your tensors to CPU before sending them over RPC, or call "
"`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
"configure device mapping. ",
"Request device mapping is not available for destination ",
remoteName);
std::vector<c10::Device> devices;
devices.reserve(tensors.size());
bool hasMappedDevice = false;
for (const auto& t : tensors) {
if (t.device().is_cpu()) {
const auto deviceIter = deviceMap.find(c10::kCPU);
if (deviceIter == deviceMap.end()) {
devices.emplace_back(c10::kCPU);
} else {
devices.emplace_back(deviceIter->second);
hasMappedDevice = true;
}
} else {
const auto deviceIter = deviceMap.find(t.device());
TORCH_CHECK(
deviceIter != deviceMap.end(),
errStr,
" for device ",
t.device(),
" but received a tensor on that device.");
devices.push_back(deviceIter->second);
hasMappedDevice = true;
}
}
if (!hasMappedDevice) {
devices.clear();
}
return devices;
}
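// The two helpers below assume all entries in `devices` share one device
// type: a single VirtualGuardImpl is created from the first device and used
// for every lookup, which is asserted per device.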
std::vector<c10::Stream> getStreamsFromPoolForDevices(
const std::vector<c10::Device>& devices) {
if (devices.empty()) {
return {};
}
c10::impl::VirtualGuardImpl impl(devices[0].type());
std::vector<c10::Stream> streams;
streams.reserve(devices.size());
for (const c10::Device& device : devices) {
TORCH_INTERNAL_ASSERT(device.type() == impl.type());
streams.push_back(impl.getStreamFromGlobalPool(device));
}
return streams;
}
std::vector<c10::Stream> getCurrentStreamsForDevices(
const std::vector<c10::Device>& devices) {
if (devices.empty()) {
return {};
}
c10::impl::VirtualGuardImpl impl(devices[0].type());
std::vector<c10::Stream> streams;
streams.reserve(devices.size());
for (const c10::Device& device : devices) {
TORCH_INTERNAL_ASSERT(device.type() == impl.type());
streams.push_back(impl.getStream(device));
}
return streams;
}
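// Collects the distinct non-CPU devices that `tensors` live on, deduplicating
// by device index with a bitset; the result is returned in ascending index
// order and CPU tensors are ignored.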
std::vector<c10::Device> getDevicesOfTensors(
const std::vector<torch::Tensor>& tensors) {
std::optional<c10::impl::VirtualGuardImpl> impl;
size_t deviceCount = 0;
std::vector<bool> indexBitset;
for (const torch::Tensor& tensor : tensors) {
if (!tensor.is_cpu()) {
c10::Device device = tensor.device();
if (!impl.has_value()) {
impl.emplace(device.type());
indexBitset.resize(impl->deviceCount());
}
TORCH_INTERNAL_ASSERT(device.type() == impl->type());
TORCH_INTERNAL_ASSERT(device.has_index());
if (!indexBitset[device.index()]) {
deviceCount++;
indexBitset[device.index()] = true;
}
}
}
std::vector<c10::Device> devices;
devices.reserve(deviceCount);
for (const auto idx : c10::irange(indexBitset.size())) {
if (indexBitset[idx]) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
devices.emplace_back(impl->type(), static_cast<c10::DeviceIndex>(idx));
}
}
return devices;
}
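// For each producer stream, record an event on it and make the consumer
// stream on the same device wait on that event, so that work later enqueued
// on the consumers only starts once the producers' pending work has finished.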
void makeStreamsWaitOnOthers(
const std::vector<c10::Stream>& consumers,
const std::vector<c10::Stream>& producers) {
for (const c10::Stream& producer : producers) {
const c10::Stream& consumer =
getStreamForDevice(consumers, producer.device());
c10::Event event(producer.device_type());
event.record(producer);
event.block(consumer);
}
}
} // namespace
C10_DEFINE_REGISTRY_WITHOUT_WARNING(
TensorPipeTransportRegistry,
TransportRegistration)
C10_DEFINE_REGISTRY_WITHOUT_WARNING(
TensorPipeChannelRegistry,
ChannelRegistration)
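// Picks the address this agent binds to: prefer the interface named by
// TP_SOCKET_IFNAME if set, otherwise the address the hostname resolves to;
// if the chosen lookup fails, fall back to 127.0.0.1. The result is computed
// once and cached for the lifetime of the process.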
const std::string& TensorPipeAgent::guessAddress() {
static const std::string uvAddress = []() {
auto ifnameEnv = c10::utils::get_env(kSocketIfnameEnvVar.c_str());
if (ifnameEnv.has_value()) {
auto [error, result] =
tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv.value());
if (error) {
LOG(WARNING) << "Failed to look up the IP address for interface "
<< ifnameEnv.value() << " (" << error.what()
<< "), defaulting to " << kDefaultUvAddress;
return kDefaultUvAddress;
}
return result;
}
auto [error, result] = tensorpipe::transport::uv::lookupAddrForHostname();
if (error) {
LOG(WARNING) << "Failed to look up the IP address for the hostname ("
<< error.what() << "), defaulting to " << kDefaultUvAddress;
return kDefaultUvAddress;
}
return result;
}();
return uvAddress;
}
namespace {
std::unique_ptr<TransportRegistration> makeUvTransport() {
auto context = tensorpipe::transport::uv::create();
std::string address = TensorPipeAgent::guessAddress();
return std::make_unique<TransportRegistration>(TransportRegistration{
std::move(context), kUvTransportPriority, std::move(address)});
}
// The UV transport is implemented using standard TCP connections. It leverages
// libuv (https://github.com/libuv/libuv) in order to be cross-platform.
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport)
#if TENSORPIPE_HAS_SHM_TRANSPORT
std::unique_ptr<TransportRegistration> makeShmTransport() {
auto context = tensorpipe::transport::shm::create();
return std::make_unique<TransportRegistration>(
TransportRegistration{std::move(context), kShmTransportPriority, ""});
}
// The SHM transport implements connections using ringbuffers in anonymous shared
// memory (plus UNIX domain sockets to bootstrap the connection and exchange
// file descriptors). It is Linux-only due to some advanced features (O_TMPFILE,
// eventfd, ...).
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, shm, makeShmTransport)
#endif // TENSORPIPE_HAS_SHM_TRANSPORT
#if TENSORPIPE_HAS_IBV_TRANSPORT
std::unique_ptr<TransportRegistration> makeIbvTransport() {
auto context = tensorpipe::transport::ibv::create();
std::string address = TensorPipeAgent::guessAddress();
return std::make_unique<TransportRegistration>(TransportRegistration{
std::move(context), kIbvTransportPriority, std::move(address)});
}
// The IBV transport sends data across using an InfiniBand queue pair, locally
// copying data to and from a staging buffer (registered with libibverbs) and
// issuing a RDMA write for transferring data across machines (plus a send for
// acknowledging it). It bootstraps using a standard TCP connection to exchange
// setup information. It is Linux-only.
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, ibv, makeIbvTransport)
#endif // TENSORPIPE_HAS_IBV_TRANSPORT
std::unique_ptr<ChannelRegistration> makeBasicChannel() {
auto context = tensorpipe::channel::basic::create();
return std::make_unique<ChannelRegistration>(
ChannelRegistration{std::move(context), kBasicChannelPriority});
}
// The basic channel is just a straightforward adapter wrapper that allows any
// transport to be used as a channel.
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, basic, makeBasicChannel)
#if TENSORPIPE_HAS_CMA_CHANNEL
std::unique_ptr<ChannelRegistration> makeCmaChannel() {
auto context = tensorpipe::channel::cma::create();
return std::make_unique<ChannelRegistration>(
ChannelRegistration{std::move(context), kCmaChannelPriority});
}
// The CMA channel uses the Linux cross-memory attach syscalls (process_vm_readv
// and _writev), which allow one process to access the private memory of another
// process (as long as they belong to the same user and other security
// constraints are satisfied). It does, more or less, what GDB does when it's
// attached to a running process.
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, cma, makeCmaChannel)
#endif // TENSORPIPE_HAS_CMA_CHANNEL
constexpr static int kNumUvThreads = 16;
std::unique_ptr<ChannelRegistration> makeMultiplexedUvChannel() {
std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
contexts.reserve(kNumUvThreads);
std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
listeners.reserve(kNumUvThreads);
for ([[maybe_unused]] const auto laneIdx : c10::irange(kNumUvThreads)) {
auto context = tensorpipe::transport::uv::create();
const std::string& address = TensorPipeAgent::guessAddress();
contexts.push_back(std::move(context));
listeners.push_back(contexts.back()->listen(address));
}
auto context = tensorpipe::channel::mpt::create(
std::move(contexts), std::move(listeners));
return std::make_unique<ChannelRegistration>(
ChannelRegistration{std::move(context), kMultiplexedUvChannelPriority});
}
// The multiplexed UV channel encapsulates multiple UV transports (each with its
// own event loop thread). Each channel will, in turn, contain multiple UV
// connections, one for each of those contexts. When sending a tensor, its data
is split into equal chunks and each chunk is sent on a different connection
// and thus driven by a different thread. This is needed to reach very high
// bandwidths.
C10_REGISTER_CREATOR(
TensorPipeChannelRegistry,
mpt_uv,
makeMultiplexedUvChannel)
} // namespace
////////////////////////// MetricsTracker /////////////////////////////////
TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(
uint64_t currentSum,
uint64_t currentCount)
: currentSum_(currentSum), currentCount_(currentCount) {}
void TensorPipeAgent::TimeSeriesMetricsTracker::addData(uint64_t dataPoint) {
currentSum_ += dataPoint;
++currentCount_;
}
float TensorPipeAgent::TimeSeriesMetricsTracker::computeAverage() const {
return currentCount_ == 0 ? 0
: static_cast<float>(
static_cast<double>(currentSum_) /
static_cast<double>(currentCount_));
}
//////////////////////// TensorpipeRpcAgent /////////////////////////////////
void TensorPipeAgent::removeFromTimeoutMap(uint64_t messageId) {
// Remove entry from timeoutMap_.
{
std::unique_lock<std::mutex> lock(timeoutMapMutex_);
auto it = messageIdToTimeout_.find(messageId);
if (it == messageIdToTimeout_.end()) {
// Already removed from the map by pollTimeoutRpcs(), no need to
// process further.
return;
}
auto& expirationTime = it->second;
auto& timedOutFuturesVector = timeoutMap_[expirationTime];
for (auto it = timedOutFuturesVector.begin();
it != timedOutFuturesVector.end();
it++) {
if (it->messageId == messageId) {
it = timedOutFuturesVector.erase(it);
break;
}
}
if (timedOutFuturesVector.empty()) {
timeoutMap_.erase(expirationTime);
}
// Remove from messageId to timeout map as well.
messageIdToTimeout_.erase(messageId);
}
}
void TensorPipeAgent::prepareNames(bool isStaticGroup) {
std::unordered_map<std::string, worker_id_t> nameToId;
if (isStaticGroup) {
nameToId = collectNames(
rankToNameStore_, workerInfo_.id_, workerInfo_.name_, worldSize_);
} else {
nameToId = collectCurrentNames(
rankToNameStore_, workerInfo_.id_, workerInfo_.name_);
}
for (const auto& entry : nameToId) {
const auto& workerName = entry.first;
const auto& workerId = entry.second;
workerIdToInfo_.emplace(workerId, WorkerInfo(workerName, workerId));
workerNameToInfo_.emplace(workerName, WorkerInfo(workerName, workerId));
}
}
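// Agree with the other group members on whether this RPC group is static,
// using the store's compareSet: the first agent to reach the store writes
// its own setting, and later agents get that stored value back. A mismatch
// between the returned value and this agent's own setting means the group
// mixes statically and dynamically initialized members.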
void TensorPipeAgent::checkAndSetStaticGroup(
const c10::intrusive_ptr<::c10d::Store>& store) {
std::string isStaticGroupKey("rpcIsStaticGroup");
std::string isStaticGroupStr = isStaticGroup_ ? "true" : "false";
std::vector<uint8_t> isStaticGroupVec(
(uint8_t*)isStaticGroupStr.c_str(),
(uint8_t*)isStaticGroupStr.c_str() + isStaticGroupStr.length());
std::vector<uint8_t> returnedVec;
returnedVec = store->compareSet(
isStaticGroupKey, std::vector<uint8_t>(), isStaticGroupVec);
std::string returnedVal = std::string(returnedVec.begin(), returnedVec.end());
// In both cases the returned value should equal isStaticGroupStr; otherwise at
// least one member of the group was initialized with a different
// static/dynamic setting.
TORCH_CHECK(
returnedVal == isStaticGroupStr,
fmt::format(
"RPC group mixes statically and dynamically initialized members which is not "
"supported. Static group property is initialized as {} and is trying to be "
"set as {}",
isStaticGroup_,
returnedVal));
}
TensorPipeAgent::TensorPipeAgent(
const c10::intrusive_ptr<::c10d::Store>& store,
std::string selfName,
worker_id_t selfId,
std::optional<int> worldSize,
TensorPipeRpcBackendOptions opts,
std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
std::vector<c10::Device> devices,
std::unique_ptr<RequestCallback> cb)
: RpcAgent(
WorkerInfo(std::move(selfName), selfId),
std::move(cb),
std::chrono::milliseconds(
static_cast<long>(opts.rpcTimeoutSeconds * kSecToMsConversion))),
isStaticGroup_(worldSize.has_value()),
store_(store),
opts_(std::move(opts)),
reverseDeviceMaps_(std::move(reverseDeviceMaps)),
devices_(std::move(devices)),
threadPool_(opts_.numWorkerThreads),
context_(std::make_shared<tensorpipe::Context>(
tensorpipe::ContextOptions().name(workerInfo_.name_))),
rankToNameStore_("names", store),
nameToAddressStore_("addrs", store),
shutdownStore_("shutdown", store) {
if (isStaticGroup_) {
worldSize_ = worldSize.value();
}
// check the static group attribute against store
checkAndSetStaticGroup(store);
// collect worker names
prepareNames(isStaticGroup_);
// Initialize the time-series metrics tracking map
timeSeriesMetrics_.emplace(kGilAverageWaitTime, TimeSeriesMetricsTracker());
}
TensorPipeAgent::~TensorPipeAgent() {
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is being destroyed";
shutdown();
}
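// Registers every viable transport and channel with the TensorPipe context
// (honoring an explicit selection in opts_ if present), publishes this
// worker's listener address to the store, resolves every peer's address,
// starts the timeout-polling thread and arms the listener for incoming pipes.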
void TensorPipeAgent::startImpl() {
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is starting";
std::vector<std::string> addresses;
int64_t lowestPriority = std::numeric_limits<int64_t>::max();
std::string lowestPriorityTransport;
// Register transports
for (auto& key : TensorPipeTransportRegistry()->Keys()) {
int64_t priority = -1;
if (opts_.transports.has_value()) {
auto iter =
std::find(opts_.transports->begin(), opts_.transports->end(), key);
if (iter == opts_.transports->end()) {
continue;
}
// Assign priorities in reverse order of occurrence in the vector, so that
// a transport that comes before another receives a higher priority.
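// For example, with opts_.transports == {"shm", "uv"}, shm is assigned
// priority 1 and uv priority 0, so shm is preferred whenever it is viable.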
priority = static_cast<std::ptrdiff_t>(opts_.transports->size()) - 1 -
(iter - opts_.transports->begin());
}
std::unique_ptr<TransportRegistration> reg =
TensorPipeTransportRegistry()->Create(key);
if (!reg->transport->isViable()) {
continue;
}
if (priority == -1) {
priority = reg->priority;
}
if (priority < lowestPriority) {
lowestPriority = priority;
lowestPriorityTransport = key;
}
addresses.push_back(c10::str(key, "://", reg->address));
context_->registerTransport(priority, key, reg->transport);
}
// Register channels
for (auto& key : TensorPipeChannelRegistry()->Keys()) {
int64_t priority = -1;
if (opts_.channels.has_value()) {
auto iter =
std::find(opts_.channels->begin(), opts_.channels->end(), key);
if (iter == opts_.channels->end()) {
continue;
}
// Assign priorities in reverse order of occurrence in the vector, so
// that a channel that comes before another receives a higher priority.
priority = static_cast<std::ptrdiff_t>(opts_.channels->size()) - 1 -
(iter - opts_.channels->begin());
}
std::unique_ptr<ChannelRegistration> reg =
TensorPipeChannelRegistry()->Create(key);
if (!reg->channel->isViable()) {
continue;
}
if (priority == -1) {
priority = reg->priority;
}
context_->registerChannel(priority, key, reg->channel);
}
listener_ = context_->listen(addresses);
// Store our own url.
const auto address = listener_->url(lowestPriorityTransport);
nameToAddressStore_.set(workerInfo_.name_, address);
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is using address "
<< address;
for (const auto& p : workerNameToInfo_) {
const auto& name = p.first;
auto nodeAddrData = nameToAddressStore_.get(name);
auto nodeAddrStr = std::string(
reinterpret_cast<const char*>(nodeAddrData.data()),
nodeAddrData.size());
workerNameToURL_.insert({name, nodeAddrStr});
}
// Start the Timeout Thread
timeoutThread_ = std::thread(&TensorPipeAgent::pollTimeoutRpcs, this);
listener_->accept([this](
const tensorpipe::Error& error,
std::shared_ptr<tensorpipe::Pipe> pipe) {
onListenerAccepted(error, pipe);
});
}
void TensorPipeAgent::onListenerAccepted(
const tensorpipe::Error& error,
std::shared_ptr<tensorpipe::Pipe>& pipe) {
if (error) {
if (error.isOfType<tensorpipe::ListenerClosedError>() &&
!rpcAgentRunning_.load()) {
// This is expected.
} else {
LOG(WARNING) << "RPC agent for " << workerInfo_.name_
<< " encountered error when accepting incoming pipe: "
<< error.what();
}
return;
}
// Accept the next connection request
listener_->accept([this](
const tensorpipe::Error& error,
std::shared_ptr<tensorpipe::Pipe> pipe) {
onListenerAccepted(error, pipe);
});
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " accepted incoming pipe from " << pipe->getRemoteName();
// Arm for server read
respond(pipe);
}
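// Reads one message from `pipe` in two steps: the descriptor (metadata) is
// read first, then buffers are allocated on streams drawn from the global
// pool for this agent's devices, the payload is read into them, and the
// deserialized message is handed to `fn` together with those streams.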
void TensorPipeAgent::pipeRead(
const std::shared_ptr<tensorpipe::Pipe>& pipe,
std::function<void(
const tensorpipe::Error&,
c10::intrusive_ptr<Message>,
std::vector<c10::Stream>)> fn) noexcept {
pipe->readDescriptor([this, fn{std::move(fn)}, pipe](
const tensorpipe::Error& error,
tensorpipe::Descriptor tpDescriptor) mutable {
if (error) {
fn(error, c10::intrusive_ptr<Message>(), {});
return;
}
std::vector<c10::Stream> streams;
{
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
streams = getStreamsFromPoolForDevices(devices_);
}
auto [tpAllocation, tpBuffers] = tensorpipeAllocate(tpDescriptor, streams);
pipe->read(
std::move(tpAllocation),
[tpDescriptor{std::move(tpDescriptor)},
tpBuffers{
std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
fn{std::move(fn)},
streams{std::move(streams)}](const tensorpipe::Error& error) mutable {
if (error) {
fn(error, c10::intrusive_ptr<Message>(), {});
return;
}
// FIXME This does some unpickling, which could be a bit expensive:
// perhaps it would be best to perform it inside the worker threads?
c10::intrusive_ptr<Message> rpcMessage =
tensorpipeDeserialize(tpDescriptor, std::move(*tpBuffers));
fn(error, std::move(rpcMessage), std::move(streams));
});
});
}
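// Serializes `rpcMessage` (copying tensors to the mapped target devices on
// the given streams when a device map is in effect) and writes it to `pipe`,
// keeping the staging buffers alive until the write callback fires.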
void TensorPipeAgent::pipeWrite(
const std::shared_ptr<tensorpipe::Pipe>& pipe,
const c10::intrusive_ptr<Message>& rpcMessage,
std::vector<c10::Device>&& devices,
std::vector<c10::Stream> streams,
std::function<void(const tensorpipe::Error&)> fn) noexcept {
auto [tpMessage, tpBuffers] =
tensorpipeSerialize(rpcMessage, std::move(devices), streams);
pipe->write(
std::move(tpMessage),
[tpBuffers{
std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers))},
fn{std::move(fn)},
streams{std::move(streams)}](const tensorpipe::Error& error) {
fn(error);
});
}
void TensorPipeAgent::sendCompletedResponseMessage(
std::shared_ptr<tensorpipe::Pipe>& pipe,
JitFuture& futureResponseMessage,
uint64_t messageId,
std::vector<c10::Stream> streams) {
if (!rpcAgentRunning_.load()) {
LOG(WARNING) << "RPC agent for " << workerInfo_.name_
<< " won't send response to request #" << messageId << " to "
<< pipe->getRemoteName() << ", as the agent is shutting down";
return;
}
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " is sending response to request #" << messageId << " to "
<< pipe->getRemoteName();
if (!futureResponseMessage.hasError()) {
c10::intrusive_ptr<Message> responseMessage =
futureResponseMessage.value().toCustomClass<Message>();
responseMessage->setId(static_cast<int64_t>(messageId));
std::vector<c10::Device> devices;
try {
devices = getDevicesForRemote(pipe->getRemoteName(), *responseMessage);
} catch (const std::exception& e) {
responseMessage =
createExceptionResponse(e.what(), static_cast<int64_t>(messageId));
}
for (const auto& tensor : responseMessage->tensors()) {
const auto device = tensor.device();
if (!device.is_cpu()) {
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
if (std::find(devices_.begin(), devices_.end(), device) ==
devices_.end()) {
std::ostringstream oss;
std::copy(
devices_.begin(),
devices_.end(),
std::ostream_iterator<c10::Device>(oss, ", "));
responseMessage = createExceptionResponse(
c10::str(
"RPC detected that a user-function output tensor is on device ",
device,
". This device is not one of the input tensor devices: ",
oss.str(),
"which is not yet supported. Please file a feature request "
"issue in PyTorch GitHub repo."),
static_cast<int64_t>(messageId));
break;
}
}
}
pipeWrite(
pipe,
responseMessage,
std::move(devices),
std::move(streams),
[this, pipe, messageId](const tensorpipe::Error& error) {
if (error) {
LOG(WARNING)
<< "RPC agent for " << workerInfo_.name_
<< " encountered error when sending response to request #"
<< messageId << " to " << pipe->getRemoteName() << ": "
<< error.what();
return;
}
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " done sending response to request #" << messageId
<< " to " << pipe->getRemoteName();
});
} else {
pipeWrite(
pipe,
createExceptionResponse(
futureResponseMessage.tryRetrieveErrorMessage(),
static_cast<int64_t>(messageId)),
/* devices */ {},
std::move(streams),
[this, pipe, messageId](const tensorpipe::Error& error) {
if (error) {
LOG(WARNING)
<< "RPC agent for " << workerInfo_.name_
<< " encountered error when sending response to request #"
<< messageId << " to " << pipe->getRemoteName() << ": "
<< error.what();
return;
}
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " done sending response to request #" << messageId
<< " to " << pipe->getRemoteName();
});
}
}
void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
pipeRead(
pipe,
[this, pipe](
const tensorpipe::Error& error,
c10::intrusive_ptr<Message> requestMessage,
std::vector<c10::Stream> streams) mutable {
if (error) {
if (shuttingDown_) {
// This is expected.
} else {
LOG(WARNING)
<< "RPC agent for " << workerInfo_.name_
<< " encountered error when reading incoming request from "
<< pipe->getRemoteName() << ": " << error.what();
}
return;
}
// Arm for next read
respond(pipe);
uint64_t messageId = requestMessage->id();
increaseCallCount(serverActiveCalls_);
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " received request #" << messageId << " from "
<< pipe->getRemoteName();
// Defer user RPC UDF run to thread pool
threadPool_.run([this,
pipe,
messageId,
requestMessage{std::move(requestMessage)},
streams{std::move(streams)}]() mutable {
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " is running request #" << messageId << " from "
<< pipe->getRemoteName() << " in thread pool";
c10::intrusive_ptr<JitFuture> futureResponseMessage;
try {
// Instead of creating a MultiStreamGuard here, the ctx is passed
// to the callback and the MultiStreamGuard is created there,
// because subsequent processing can switch threads due to 1)
// waiting for RRef arguments to become ready 2) async_execution.
// Besides, the `ctx` also needs to be propagated to
// `process***Call` methods to synchronize CUDA streams there
// to make sure that we fetch the correct value from `to_here()`
// call.
futureResponseMessage =
cb_->operator()(*requestMessage, std::move(streams));
} catch (const std::exception& /* unused */) {
futureResponseMessage =
c10::make_intrusive<JitFuture>(at::AnyClassType::get());
futureResponseMessage->setError(std::current_exception());
}
increaseCallCount(serverActiveAsyncCalls_);
futureResponseMessage->addCallback(
[this, pipe, messageId](
JitFuture& futureResponseMessage) mutable {
decreaseCallCount(serverActiveCalls_);
decreaseCallCount(serverActiveAsyncCalls_);
auto streams = getCurrentStreamsForDevices(
futureResponseMessage.devices());
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId, std::move(streams));
});
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " done running request #" << messageId << " from "
<< pipe->getRemoteName() << " in thread pool";
});
});
}
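// Sends a request to `toWorkerInfo`: reuses (or lazily establishes) the
// client pipe to that worker, registers an AtomicJitFuture for the response,
// tracks the call in timeoutMap_ unless the timeout is zero (i.e. infinite),
// writes the request over the pipe, and arms a read for the matching
// response. Returns the future that will eventually hold the response
// message or an error.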
c10::intrusive_ptr<JitFuture> TensorPipeAgent::send(
const WorkerInfo& toWorkerInfo,
c10::intrusive_ptr<Message> requestMessage,
const float rpcTimeoutSeconds,
const DeviceMap& deviceMap) {
TORCH_CHECK(
requestMessage->isRequest(),
"TensorPipeAgent::send(..) is only for sending requests.");
if (!rpcAgentRunning_.load()) {
auto err = c10::str(
"Node ",
RpcAgent::getWorkerInfo().id_,
" tried to send() a message of type ",
requestMessage->type(),
" but RPC is no longer running on this node.");
TORCH_CHECK(false, err);
}
const auto& url = findWorkerURL(toWorkerInfo);
decltype(connectedPipes_)::iterator it;
{
std::unique_lock<std::mutex> lock(connectedPipesMutex_);
// See if we already have a connection to this address or not
it = connectedPipes_.find(toWorkerInfo.id_);
if (it == connectedPipes_.end()) {
// An instance of ClientPipe cannot be copied or moved as it contains a
// mutex, so it must be constructed in place; we use piecewise construction
// to force that (working around an emplace issue on GCC 5).
it = connectedPipes_
.emplace(
std::piecewise_construct,
std::forward_as_tuple(toWorkerInfo.id_),
std::forward_as_tuple(context_->connect(
url,
tensorpipe::PipeOptions().remoteName(
toWorkerInfo.name_))))
.first;
}
}
ClientPipe& clientPipe = it->second;
std::shared_ptr<torch::distributed::rpc::TensorPipeAgent::AtomicJitFuture>
futureResponseMessage;
{
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
futureResponseMessage = std::make_shared<AtomicJitFuture>(devices_);
}
uint64_t messageId = nextMessageID_++;
requestMessage->setId(static_cast<int64_t>(messageId));
{
std::unique_lock<std::mutex> lock(clientPipe.mutex_);
clientPipe.pendingResponseMessage_[messageId] = futureResponseMessage;
}
// Get devices for tensors in the request message. This can throw if device
// maps are not configured properly for this request.
std::vector<c10::Device> devices;
if (deviceMap.empty()) {
devices =
getDevicesForRemote(clientPipe.pipe_->getRemoteName(), *requestMessage);
} else {
// If deviceMap is specified, use that instead.
devices = getDevicesForTensors(
requestMessage->tensors(),
deviceMap,
clientPipe.pipe_->getRemoteName());
}
futureResponseMessage->jitFuture->addCallback(
[this](JitFuture& /* unused */) {
TORCH_INTERNAL_ASSERT(
this->threadPool_.inThreadPool(),
"Future marked complete from outside the thread pool");
});
increaseCallCount(clientActiveCalls_);
// Use the default RPC timeout if no timeout is specified for this send call
auto timeout = rpcTimeoutSeconds == kUnsetRpcTimeout
? getRpcTimeout()
: std::chrono::milliseconds(
static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
// We only add to the timeoutMap_ if the timeout is not 0. Per our
// documentation, a user-provided timeout of 0 indicates the RPC should never
// expire (infinite timeout), so there is no need to track it in the
// timeoutMap_.
steady_clock_time_point expirationTime;
if (timeout.count() != 0) {
// Compute the expiration time for this message based on the timeout
expirationTime = computeRpcMessageExpiryTime(timeout);
// Add the Future to the right vector in the timeoutMap_
{
std::unique_lock<std::mutex> lock(timeoutMapMutex_);
auto& timeoutFuturesVector = timeoutMap_[expirationTime];
messageIdToTimeout_.emplace(messageId, expirationTime);
timeoutFuturesVector.emplace_back(
messageId, futureResponseMessage, timeout);
}
timeoutThreadCV_.notify_one();
}
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
<< messageId << " to " << clientPipe.pipe_->getRemoteName();
std::vector<c10::Stream> streams;
{
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
streams = getStreamsFromPoolForDevices(devices_);
}
makeStreamsWaitOnOthers(
streams,
getCurrentStreamsForDevices(
getDevicesOfTensors(requestMessage->tensors())));
pipeWrite(
clientPipe.pipe_,
requestMessage,
std::move(devices),
std::move(streams),
[this, &clientPipe, messageId](const tensorpipe::Error& error) mutable {
if (error) {
if (error.isOfType<tensorpipe::PipeClosedError>() &&
!rpcAgentRunning_.load()) {
// This is expected.
} else {
LOG(WARNING) << "RPC agent for " << workerInfo_.name_
<< " encountered error when sending outgoing request #"
<< messageId << " to "
<< clientPipe.pipe_->getRemoteName() << ": "
<< error.what();
}
handleClientError(clientPipe, error);
return;
}
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " sent request #"
<< messageId << " to " << clientPipe.pipe_->getRemoteName();
pipeRead(
clientPipe.pipe_,
[this, &clientPipe](
const tensorpipe::Error& error,
c10::intrusive_ptr<Message> responseMessage,
std::vector<c10::Stream> streams) {
if (error) {
if (error.isOfType<tensorpipe::PipeClosedError>() &&
!rpcAgentRunning_.load()) {
// This is expected.
} else {
LOG(WARNING)
<< "RPC agent for " << workerInfo_.name_
<< " encountered error when reading incoming response from "
<< clientPipe.pipe_->getRemoteName() << ": "
<< error.what();
}
handleClientError(clientPipe, error);
return;
}
// Identify future response message by message ID
uint64_t messageId = responseMessage->id();
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " received response #" << messageId << " from "
<< clientPipe.pipe_->getRemoteName();
std::shared_ptr<AtomicJitFuture> futureResponseMessage;
{
std::lock_guard<std::mutex> lock(clientPipe.mutex_);
// A read error would have caused all subsequent callbacks to be
// invoked with an error, so we should never reach this point.
TORCH_INTERNAL_ASSERT(
!clientPipe.inError_, "Shouldn't be in error state");
auto it = clientPipe.pendingResponseMessage_.find(messageId);
TORCH_INTERNAL_ASSERT(
it != clientPipe.pendingResponseMessage_.end(),
"message ID ",
messageId,
" is not recognized");
futureResponseMessage = std::move(it->second);
clientPipe.pendingResponseMessage_.erase(it);
}
// Remove entry from timeoutMap_.
removeFromTimeoutMap(messageId);
if (responseMessage->type() == MessageType::EXCEPTION) {
markFutureWithError(
std::move(futureResponseMessage),
std::string(
responseMessage->payload().begin(),
responseMessage->payload().end()));
} else {
markFutureAsComplete(
std::move(futureResponseMessage),
std::move(responseMessage),
std::move(streams));
}
});
});
return futureResponseMessage->jitFuture;
}
void TensorPipeAgent::handleClientError(
ClientPipe& clientPipe,
const tensorpipe::Error& error) {
// When an error occurs on a pipe all pending operations will be aborted and
// all callbacks invoked with error, hence we immediately flush all future
// messages belonging to this pipe.
decltype(clientPipe.pendingResponseMessage_) pendingMsgs;
{
std::lock_guard<std::mutex> lock(clientPipe.mutex_);
std::swap(clientPipe.pendingResponseMessage_, pendingMsgs);
clientPipe.inError_ = true;
}
std::string errorMsg = error.what();
for (auto& p : pendingMsgs) {
markFutureWithError(std::move(p.second), errorMsg);
// Remove entry from timeoutMap_.
removeFromTimeoutMap(p.first);
}
}
void TensorPipeAgent::pollTimeoutRpcs() {
while (rpcAgentRunning_.load()) {
std::unique_lock<std::mutex> lock(timeoutMapMutex_);
// We sleep until the earliest expiring RPC in the timeoutMap_. We must
// also ensure that we sleep while the map is empty, and we exit sleeping
// if the RPC Agent has been shutdown.
for (;;) {
if (!rpcAgentRunning_.load()) {
return;
}
if (!timeoutMap_.empty()) {
steady_clock_time_point earliestTimeout = timeoutMap_.begin()->first;
if (std::chrono::steady_clock::now() >= earliestTimeout) {
break;
}
timeoutThreadCV_.wait_until(lock, earliestTimeout);
} else {
timeoutThreadCV_.wait(lock);
}
}
// Move all these futures to a separate vector so we can process them
// outside the lock.
std::vector<TimeoutMessageMetadata> timedOutFutures =
std::move(timeoutMap_.begin()->second);
// We can safely remove this key from the timeoutMap_ since all these
// futures will be processed.
timeoutMap_.erase(timeoutMap_.begin());
for (auto& timeoutMetadata : timedOutFutures) {
// Remove from messageIdToTimeout map.
messageIdToTimeout_.erase(timeoutMetadata.messageId);
}
lock.unlock();
// Set an error on futures added to the timedOutFutures vector. We do this
// outside the lock to prevent potential lock-order-inversions by callbacks
// triggered by the setError call.
for (auto& timeoutMetadata : timedOutFutures) {
std::string errorMsg =
fmt::format(kRpcTimeoutErrorStr, timeoutMetadata.timeout.count());
auto err = makeRPCError(errorMsg, RPCErrorType::TIMEOUT);
markFutureWithError(
std::move(timeoutMetadata.responseFuture), std::move(err));
}
}
}
void TensorPipeAgent::leaveGroup() {
std::unique_lock<std::mutex> lock(callCountMutex_);
// Wait for this worker's active client calls to drain to 0; after that we
// will shut down (any future calls will be dropped).
callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });
// Remove this agent's WorkerInfo from store
removeCurrentName(rankToNameStore_, workerInfo_.id_, workerInfo_.name_);
// Set internal variable to be used during destructor
shuttingDown_ = true;
}
// TODO: Remove join()
void TensorPipeAgent::join(bool shutdown, float /* unused */) {
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is joining";
if (!isStaticGroup_) {
leaveGroup();
return;
}
// This method behaves like a barrier, as it can only return once all workers
// have no more requests pending, including "nested" requests (triggered from
// within the remote code of another call) and "follow-up" requests (triggered
// from the callback of a future).
while (true) {
{
std::unique_lock<std::mutex> lock(callCountMutex_);
// It is enough to wait for there to be no more active client calls, since
// each server call corresponds to a client call for some other worker.
callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });
// We'd like to immediately proceed with the allreduce, but it's a call
// that may block for some time, as it waits for other workers to also
// complete all their active client calls. While we call allreduce we must
// hold the mutex, or else the count we send to other workers may get
// stale (e.g., if some nested call happens in the meantime). But we can't
// hold the lock for an indeterminately long time, as that would block
// other operations (e.g., send). Thus we must release the lock and only
// re-acquire it when all workers are ready to proceed with the allreduce.
// We perform this synchronization using a barrier.
}
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " completed all client calls and is entering a barrier";
syncCallCount(shutdownStore_, worldSize_);
{
std::unique_lock<std::mutex> lock(callCountMutex_);
// At this point, the count may have become non-zero again. We can't wait
// for those calls to complete as other workers are waiting for us in the
// allreduce and we would block them. Thus we send our count even if it is
// non-zero and if anyone (be it us or another worker) has a non-zero
// count we'll just do another round.
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " exited the barrier and found " << clientActiveCalls_
<< " active client calls";
int totalClientActiveCalls =
syncCallCount(shutdownStore_, worldSize_, clientActiveCalls_);
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " completed sync call counts and got a total of "
<< totalClientActiveCalls
<< " active client calls across all workers";
if (totalClientActiveCalls == 0) {
if (shutdown) {
shuttingDown_ = true;
syncCallCount(shutdownStore_, worldSize_);
}
break;
}
}
}
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done joining";
}
void TensorPipeAgent::shutdownImpl() {
// FIXME Isn't it too verbose for a library to print logs in normal operation?
LOG(INFO) << "RPC agent for " << workerInfo_.name_ << " is shutting down";
// Join the Timeout Thread
timeoutThreadCV_.notify_one();
if (timeoutThread_.joinable()) {
timeoutThread_.join();
}
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " done waiting for timeout thread to join";
// This will close all the pipes and listeners, invoke all callbacks with
// errors, turn down the I/O event loops and wait for everything to terminate.
context_->join();
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " done waiting for TensorPipe context to join";
// NOTE: waitWorkComplete must be called last, after all TensorPipe listeners
// have been shut down, in order to drain any already accepted work in the
// thread pool. If we drained before shutting down the listeners, new work
// could still be accepted between the drain and the listener shutdown; that
// work would keep executing in the thread pool and might cause issues while
// the system shuts down.
threadPool_.waitWorkComplete();
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " done waiting for thread pool to complete work";
}
const WorkerInfo& TensorPipeAgent::getWorkerInfo(
const std::string& workerName) const {
std::unordered_map<std::string, WorkerInfo>::const_iterator it;
{
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
it = workerNameToInfo_.find(workerName);
}
TORCH_CHECK(
it != workerNameToInfo_.end(),
fmt::format(
"name:{},rank:{} could not find destination name {}",
workerInfo_.name_,
workerInfo_.id_,
workerName));
return it->second;
}
const WorkerInfo& TensorPipeAgent::getWorkerInfo(worker_id_t workerId) const {
std::unordered_map<worker_id_t, WorkerInfo>::const_iterator it;
{
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
it = workerIdToInfo_.find(workerId);
}
TORCH_CHECK(
it != workerIdToInfo_.end(),
fmt::format(
"name:{},rank:{} could not find destination id {}",
workerInfo_.name_,
workerInfo_.id_,
workerId));
return it->second;
}
std::vector<WorkerInfo> TensorPipeAgent::getWorkerInfos() const {
std::vector<WorkerInfo> workerInfos;
workerInfos.reserve(workerNameToInfo_.size());
for (auto& item : workerNameToInfo_) {
workerInfos.emplace_back(item.second);
}
return workerInfos;
}
const std::string& TensorPipeAgent::findWorkerURL(
const WorkerInfo& worker) const {
std::unordered_map<std::string, std::string>::const_iterator it;
{
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
it = workerNameToURL_.find(worker.name_);
}
TORCH_CHECK(
it != workerNameToURL_.end(),
fmt::format(
"name:{},rank:{} could not find destination url for name {}",
workerInfo_.name_,
workerInfo_.id_,
worker.name_));
return it->second;
}
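// Called when a worker joins or leaves a dynamically sized RPC group. On
// join, record the new worker's info, URL, reverse device maps and devices;
// on leave, erase its entries and drop any reverse device maps and devices
// that the remaining group no longer uses.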
void TensorPipeAgent::updateGroupMembership(
const WorkerInfo& workerInfo,
const std::vector<c10::Device>& devices,
const std::unordered_map<std::string, DeviceMap>& reverseDeviceMaps,
bool isJoin) {
std::string name = workerInfo.name_;
worker_id_t id = workerInfo.id_;
// Rank with workerInfo is joining the group, update internal mappings
if (isJoin) {
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
workerIdToInfo_.emplace(id, workerInfo);
workerNameToInfo_.emplace(name, workerInfo);
// TODO: we should get nodeAddrStr in the joining process, then pass in as
// an argument rather than getting from store each time
auto nodeAddrData = nameToAddressStore_.get(name);
auto nodeAddrStr = std::string(
reinterpret_cast<const char*>(nodeAddrData.data()),
nodeAddrData.size());
workerNameToURL_.insert({name, nodeAddrStr});
for (const auto& it : reverseDeviceMaps) {
if (reverseDeviceMaps_.find(it.first) == reverseDeviceMaps_.end()) {
reverseDeviceMaps_[it.first] = it.second;
}
}
// TODO: clean up mutex for devices_ usage
// Add devices that have not been added yet
for (const auto& it : devices) {
if (std::find(devices_.begin(), devices_.end(), it) == devices_.end()) {
devices_.push_back(it);
}
}
} else {
workerIdToInfo_.erase(id);
workerNameToInfo_.erase(name);
workerNameToURL_.erase(name);
// remove reverse device maps that are no longer used
for (auto it = reverseDeviceMaps_.begin();
it != reverseDeviceMaps_.end();) {
if (reverseDeviceMaps.find(it->first) == reverseDeviceMaps.end()) {
it = reverseDeviceMaps_.erase(it);
} else {
it++;
}
}
// remove devices that are no longer used
for (auto it = devices_.begin(); it != devices_.end();) {
if (std::find(devices.begin(), devices.end(), *it) == devices.end()) {
it = devices_.erase(it);
} else {
it++;
}
}
}
}
std::unordered_map<std::string, std::string> TensorPipeAgent::getMetrics() {
std::unordered_map<std::string, std::string> metrics;
metrics[kThreadPoolSize] = std::to_string(threadPool_.size());
metrics[kNumIdleThreads] = std::to_string(threadPool_.numAvailable());
{
std::unique_lock<std::mutex> lock(callCountMutex_);
metrics[kClientActiveCalls] = std::to_string(clientActiveCalls_);
metrics[kServerActiveCalls] = std::to_string(serverActiveCalls_);
metrics[kServerActiveAsyncCalls] = std::to_string(serverActiveAsyncCalls_);
}
if (isGILProfilingEnabled()) {
{
std::unique_lock<std::mutex> lock(metricsMutex_);
// Include the averages for each time series metric. This is just the GIL
// Wait Time for now.
auto averageGilWaitTime =
timeSeriesMetrics_[kGilAverageWaitTime].computeAverage();
lock.unlock();
metrics[kGilAverageWaitTime] = std::to_string(averageGilWaitTime);
}
}
return metrics;
}
void TensorPipeAgent::addGilWaitTime(
const std::chrono::microseconds gilWaitTime) {
std::lock_guard<std::mutex> lock(metricsMutex_);
timeSeriesMetrics_[kGilAverageWaitTime].addData(gilWaitTime.count());
}
TensorPipeAgent::NetworkDataDict TensorPipeAgent::getNetworkData() {
std::lock_guard<std::mutex> lock(networkDataMutex_);
return networkData_;
}
NetworkSourceInfo TensorPipeAgent::getNetworkSourceInfo() {
NetworkSourceInfo info = {
RpcAgent::getWorkerInfo().id_,
nameToAddressStore_.get(RpcAgent::getWorkerInfo().name_)};
return info;
}
void TensorPipeAgent::trackNetworkData(
uint64_t requestSize,
uint64_t responseSize,
const std::string& destWorkerName) {
std::lock_guard<std::mutex> lock(networkDataMutex_);
networkData_[destWorkerName].numCalls++;
networkData_[destWorkerName].totalSentBytes += requestSize;
networkData_[destWorkerName].totalRecvBytes += responseSize;
}
void TensorPipeAgent::trackNetworkError(
uint64_t requestSize,
const std::string& destWorkerName) {
std::lock_guard<std::mutex> lock(networkDataMutex_);
networkData_[destWorkerName].numCalls++;
networkData_[destWorkerName].totalSentBytes += requestSize;
networkData_[destWorkerName].totalErrors++;
}
void TensorPipeAgent::increaseCallCount(int32_t& count) {
{
std::unique_lock<std::mutex> lock(callCountMutex_);
++count;
}
callCountCV_.notify_all();
}
void TensorPipeAgent::decreaseCallCount(int32_t& count) {
{
std::unique_lock<std::mutex> lock(callCountMutex_);
--count;
}
callCountCV_.notify_all();
}
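// The test_and_set() on AtomicJitFuture::isComplete ensures the underlying
// JitFuture is completed at most once: whichever of markFutureAsComplete and
// markFutureWithError runs first wins, and any later call becomes a no-op.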
void TensorPipeAgent::markFutureAsComplete(
std::shared_ptr<AtomicJitFuture> atomicFuture,
c10::intrusive_ptr<Message> message,
std::vector<c10::Stream> streams) {
if (!atomicFuture->isComplete.test_and_set()) {
// Completing the future will run its callbacks, which could execute
// arbitrary user code. To prevent blocking or stalling the TensorPipe event
// loops, we defer this to a worker thread.
threadPool_.run([this,
atomicFuture{std::move(atomicFuture)},
message{std::move(message)},
streams{std::move(streams)}]() mutable {
c10::MultiStreamGuard guard(streams);
std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages =
message->getStorages();
atomicFuture->jitFuture->markCompleted(
std::move(message), std::move(storages));
// The future's callbacks may schedule further RPCs, increasing the count.
// Thus we must decrease it after completing the future, otherwise it may
// briefly dip to zero and trick join into thinking all work is done.
decreaseCallCount(clientActiveCalls_);
});
}
}
void TensorPipeAgent::markFutureWithError(
std::shared_ptr<AtomicJitFuture> atomicFuture,
std::string errorMsg) {
if (!atomicFuture->isComplete.test_and_set()) {
// Completing the future will run its callbacks, which could execute
// arbitrary user code. To prevent blocking or stalling the TensorPipe event
// loops, we defer this to a worker thread.
threadPool_.run([this,
atomicFuture{std::move(atomicFuture)},
errorMsg{std::move(errorMsg)}]() mutable {
atomicFuture->jitFuture->setError(
std::make_exception_ptr(std::runtime_error(errorMsg)));
// The future's callbacks may schedule further RPCs, increasing the count.
// Thus we must decrease it after completing the future, otherwise it may
// briefly dip to zero and trick join into thinking all work is done.
decreaseCallCount(clientActiveCalls_);
});
}
}
std::vector<c10::Device> TensorPipeAgent::getDevicesForRemote(
const std::string& remoteName,
const Message& message) const {
std::unordered_map<std::string, DeviceMap> deviceMaps;
{
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
deviceMaps = message.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_;
}
const auto errStr = c10::str(
"TensorPipe RPC backend only supports CPU tensors by default, please "
"move your tensors to CPU before sending them over RPC, or call "
"`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
"configure device mapping. ",
message.isRequest() ? "Request" : "Response",
" device mapping is not available for destination ",
remoteName);
const auto& iter = deviceMaps.find(remoteName);
if (iter == deviceMaps.end()) {
for (const auto& t : message.tensors()) {
TORCH_CHECK(
t.device().is_cpu(),
errStr,
", but found tensor on device: ",
t.device());
}
return {};
} else {
return getDevicesForTensors(message.tensors(), iter->second, errStr);
}
}
DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo& dst) const {
auto it = opts_.deviceMaps.find(dst.name_);
if (it == opts_.deviceMaps.end()) {
return {};
}
return it->second;
}
const c10::intrusive_ptr<::c10d::Store> TensorPipeAgent::getStore() const {
return store_;
}
TensorPipeRpcBackendOptions TensorPipeAgent::getBackendOptions() const {
return opts_;
}
const std::vector<c10::Device>& TensorPipeAgent::getDevices() const {
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
return devices_;
}
size_t TensorPipeAgent::timeoutMapSize() {
std::unique_lock<std::mutex> lock(timeoutMapMutex_);
return timeoutMap_.size();
}
size_t TensorPipeAgent::numPendingResponses() {
std::unique_lock<std::mutex> lock(callCountMutex_);
return clientActiveCalls_;
}
size_t TensorPipeAgent::messageIdToTimeoutMapSize() {
std::unique_lock<std::mutex> lock(timeoutMapMutex_);
return messageIdToTimeout_.size();
}
} // namespace torch::distributed::rpc
#endif // USE_TENSORPIPE