Files
pytorch/torch/csrc/distributed/rpc/process_group_agent.cpp
Shen Li 2486b0ba82 Add Python RRef as args and return value (#25499)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25499

See #23110 for model parallel design details, and #26759 for the RRef
protocol. This commit adds support for using RRefs as Python UDF arguments
and return values. RRefs can now be shared from owner to user, from user to
owner, or from user to user.
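
For illustration only, a minimal sketch of the new usage, assuming an initialized RPC group and the `dist.remote`-style API from #23110 (the worker names, module path, and the `rpc_sync`/`to_here` spellings are assumptions, not necessarily the exact surface at this commit):

```python
import torch
import torch.distributed.rpc as rpc  # module path assumed; see #23110

def add(a, b):
    return a + b

# On "worker0": kick off a remote computation whose result lives on "worker1".
# dist.remote returns an RRef (a reference to the remote value) immediately.
rref = rpc.remote("worker1", add, args=(torch.ones(2), torch.ones(2)))

# The RRef can be passed to or returned from Python UDFs; to_here() fetches
# the value from the owner only when it is actually needed.
print(rref.to_here())  # tensor([2., 2.])
```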

Limitations:
1. No implicit type conversion yet. (#27099)
2. No failure handling and retry. (#26116)
3. UDF is not yet blocked until all RRefs are confirmed. (#27098)
4. Internal RRef control messages are not idempotent yet. (#26116)
5. Cannot delete RRefs correctly when there are circular dependencies. (#27096)

Main changes:

1. Added `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL` to `Message.h` to represent `dist.remote` invocations.
2. Added `SCRIPT_RREF_FETCH_CALL`, `PYTHON_RREF_FETCH_CALL`, `RREF_USER_ACCEPT`, `RREF_USER_DELETE`, `RREF_CHILD_ACCEPT`, and `RREF_FORK_REQUEST` to `Message.h` as internal RRef control messages.
3. Added new message request handling code to `functions.cpp`; the message formats are defined in `script_remote_call.h`, `python_remote_call.h`, and `rref_proto.h`.
4. Added a `PyRRef` type in `py_rref.h` and `py_rref.cpp` which holds a shared pointer to the C++ `RRef` type. `PyRRef` wraps the C++ API and also implements RRef pickling and unpickling. RRef fork-related control messages are sent during the RRef pickling/unpickling procedure (see the sketch after this list).
5. Updated `RRef.h` and `RRef.cpp` accordingly to support `py::object` RRefs.
6. RRef context (reference count, etc.) is tracked in `rref_context.h` and `rref_context.cpp`.
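
As referenced in item 4 above, here is a sketch of the user-to-user sharing flow that these pieces support (same assumptions about worker names and module path as the earlier example): forwarding an RRef to another worker ships only the reference, and the fork/accept control messages keep the owner's bookkeeping in `rref_context` consistent before the value is fetched.

```python
import torch
import torch.distributed.rpc as rpc  # module path assumed; see #23110

def make_tensor():
    return torch.ones(2)

def consume(rref, offset):
    # Runs on "worker2", a *user* of the RRef owned by "worker1". Unpickling
    # the RRef argument triggers the fork-related control messages so the
    # owner learns about the new user before the value is fetched.
    return rref.to_here() + offset

# On "worker0": the value lives on "worker1"; "worker0" only holds a user RRef.
rref = rpc.remote("worker1", make_tensor)

# Forwarding the RRef to "worker2" is user-to-user sharing: only the reference
# travels over the wire, not the tensor.
result = rpc.rpc_sync("worker2", consume, args=(rref, 1))
```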

Test Plan:
Imported from OSS

buck test mode/dev-nosan //caffe2/test:rpc_fork

Differential Revision: D17184146

Pulled By: mrshenli

fbshipit-source-id: a3a268efc087ac1ef489136ab957080382629265
2019-10-03 17:47:12 -07:00

404 lines
14 KiB
C++

#include <torch/csrc/distributed/rpc/process_group_agent.h>
#include <c10/util/C++17.h>
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <Python.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
// Write the message into the given ostream
void serialize(const Message& message, std::ostream& os) {
// We cast const void* to void* here because we need to create a tensor using
// that memory space. It is fine as that tensor stays function-local, and will
// not be modified during its lifetime.
auto payload = const_cast<void*>( // NOLINT
static_cast<const void*>(message.payload().data()));
auto payload_size = message.payload().size();
// getting tensor table from the message
std::vector<torch::Tensor> tensors = message.tensors();
// append payload as a tensor
tensors.push_back(torch::from_blob(payload, payload_size, {torch::kChar}));
// append the message id as a tensor
tensors.push_back(torch::tensor({message.id()}, {torch::kInt64}));
torch::save(tensors, os);
}
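// Read a Message of the given type back from the given istream. serialize()
// appends the raw payload and then the message id as extra tensors; this pops
// them back off, leaving the original tensor table.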
Message deserialize(MessageType type, std::istream& is) {
std::vector<torch::Tensor> tensors;
torch::load(tensors, is);
TORCH_CHECK(tensors.size() >= 2, "Failed to deserialize a message.");
auto idTensor = std::move(tensors.back());
tensors.pop_back();
auto payloadTensor = std::move(tensors.back());
tensors.pop_back();
TORCH_INTERNAL_ASSERT(idTensor.numel() == 1, "expected a single-element id tensor");
int64_t id = idTensor.storage().data<int64_t>()[0];
std::vector<char> payload(payloadTensor.numel());
if (payloadTensor.numel() > 0) {
std::memcpy(
payload.data(), payloadTensor.storage().data(), payloadTensor.numel());
}
return Message(std::move(payload), std::move(tensors), type, id);
}
} // namespace
////////////////////////// MessageCounter /////////////////////////////////
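// Thread-safe per-peer message counters. ProcessGroupAgent keeps one counter
// for sends and one for receives; sync()/hasPendingMessage() compare their
// snapshots across workers to decide whether any message is still in flight.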
ProcessGroupAgent::MessageCounter::MessageCounter(int worldSize)
: counters_(worldSize) {}
void ProcessGroupAgent::MessageCounter::increment(int dst) {
std::lock_guard<std::mutex> guard(mutex_);
++counters_[dst];
}
std::vector<int64_t> ProcessGroupAgent::MessageCounter::snapshot() {
std::lock_guard<std::mutex> guard(mutex_);
return counters_;
}
//////////////////////// ProcessGroupAgent /////////////////////////////////
void ProcessGroupAgent::collectNames() {
const std::string& workerName = workerInfo_.name_;
const auto worldSize = pg_->getSize();
// use c10d allgather to collect names
torch::Tensor nameTensor =
torch::zeros({WorkerInfo::MAX_NAME_LEN}, torch::kChar);
memcpy(nameTensor.storage().data(), workerName.c_str(), workerName.length());
std::vector<torch::Tensor> inputName = {nameTensor};
std::vector<std::vector<torch::Tensor>> outputNames(1);
for (int i = 0; i < worldSize; ++i) {
outputNames[0].emplace_back(
torch::empty({WorkerInfo::MAX_NAME_LEN}, {torch::kChar}));
}
pg_->allgather(outputNames, inputName)->wait();
// convert collected name tensors into string names
for (int i = 0; i < worldSize; ++i) {
torch::Tensor& tensor = outputNames[0][i];
std::string peerName((const char*)tensor.storage().data<signed char>());
TORCH_CHECK(
nameMap_.find(peerName) == nameMap_.end(),
"RpcAgent name ",
peerName,
" is not unique.");
nameMap_[std::move(peerName)] = i;
}
}
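// The constructor gathers peer names via collectNames(), verifies that this
// worker's name resolves to its ProcessGroup rank, builds the rank-indexed
// allWorkerInfo_ table, eagerly constructs the PythonRpcHandler singleton, and
// spawns the listener thread that serves incoming messages.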
ProcessGroupAgent::ProcessGroupAgent(
std::string workerName,
std::shared_ptr<c10d::ProcessGroup> pg,
int numSendRecvThreads)
: RpcAgent(
WorkerInfo(std::move(workerName), pg->getRank()),
c10::guts::make_unique<RequestCallbackImpl>()),
pg_(std::move(pg)),
sendCounts_(pg_->getSize()),
recvCounts_(pg_->getSize()),
nextId_(0),
sendMutexes_(pg_->getSize()),
threadPool_(numSendRecvThreads) {
collectNames();
TORCH_CHECK(
nameMap_.size() > 1,
"ProcessGroupAgent requires world_size to "
"be at least 2, but got ",
nameMap_.size());
auto workerRankIter = nameMap_.find(workerInfo_.name_);
TORCH_CHECK(
workerRankIter != nameMap_.end(),
"Failed to resolve worker "
"name ",
workerInfo_.name_,
" to a ProcessGroup rank.");
TORCH_CHECK(
pg_->getRank() == workerRankIter->second,
"Resolved worker rank ",
workerRankIter->second,
" does not match ProcessGroup rank ",
pg_->getRank());
// temporary vector to sort names in rank order
std::vector<std::string> tmpWorkerIds(pg_->getSize());
for (auto& entry : nameMap_) {
tmpWorkerIds[entry.second] = entry.first;
}
allWorkerInfo_.reserve(pg_->getSize());
for (int rank = 0; rank < (int)tmpWorkerIds.size(); ++rank) {
allWorkerInfo_.emplace_back(std::move(tmpWorkerIds[rank]), rank);
}
// construct PythonRpcHandler singleton here
PythonRpcHandler::getInstance();
listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
}
const WorkerInfo& ProcessGroupAgent::getWorkerInfo(
const std::string& workerName) const {
const auto idIter = nameMap_.find(workerName);
TORCH_CHECK(
idIter != nameMap_.end(), "Unknown destination worker ", workerName);
return allWorkerInfo_[idIter->second];
}
const WorkerInfo& ProcessGroupAgent::getWorkerInfo(worker_id_t id) const {
return allWorkerInfo_[id];
}
void ProcessGroupAgent::join() {
// Every process i sends a SHUTDOWN message to process i + 1. This is
// necessary for now because:
// 1. There is no abort API for ProcessGroup::recvAnysource yet. We have to
// feed it a message or kill the thread.
// 2. A GLOO process cannot send a message to itself (there is an ongoing
// effort to fix this problem).
sync();
std::unique_lock<std::mutex> lock(futureMutex_);
futureCV_.wait(lock, [this] { return futures_.empty(); });
lock.unlock();
pg_->barrier()->wait();
int dst = (pg_->getRank() + 1) % pg_->getSize();
enqueueSend(
SendWork(allWorkerInfo_[dst], Message({}, {}, MessageType::SHUTDOWN)));
threadPool_.waitWorkComplete();
listenerThread_.join();
}
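// Returns true if any message sent by any worker has not yet been processed
// by its destination. Each worker contributes a {2, worldSize} tensor of its
// recv and send counts via allgather, and the per-pair counts are compared.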
bool ProcessGroupAgent::hasPendingMessage() {
const auto worldSize = pg_->getSize();
std::vector<int64_t> snapshot;
snapshot.reserve(2 * worldSize);
auto recvSnapshot = recvCounts_.snapshot();
auto sendSnapshot = sendCounts_.snapshot();
snapshot.insert(
snapshot.end(),
std::make_move_iterator(recvSnapshot.begin()),
std::make_move_iterator(recvSnapshot.end()));
snapshot.insert(
snapshot.end(),
std::make_move_iterator(sendSnapshot.begin()),
std::make_move_iterator(sendSnapshot.end()));
std::vector<torch::Tensor> inputSnapshot = {
torch::from_blob(snapshot.data(), {2, worldSize}, {torch::kInt64})};
// allgather both send and recv counts in one shot
std::vector<std::vector<torch::Tensor>> outputSnapshots(1);
for (int i = 0; i < worldSize; ++i) {
outputSnapshots[0].emplace_back(
torch::zeros({2, worldSize}, {torch::kInt64}));
}
pg_->allgather(outputSnapshots, inputSnapshot)->wait();
// loop through all send/recv pairs to make sure that all sent messages are
// processed.
const auto& peerCounts = outputSnapshots[0];
for (int from = 0; from < worldSize; ++from) {
for (int to = 0; to < worldSize; ++to) {
// peerCounts[x][0] is recv counts, and peerCounts[x][1] is send counts
const auto& sentCnt = peerCounts[from][1][to].data_ptr<int64_t>()[0];
const auto& recvCnt = peerCounts[to][0][from].data_ptr<int64_t>()[0];
// NB: we cannot throw an error when sentCnt < recvCnt here, because send
// and recv counts on different workers are read in a distributed manner.
// It is possible that the sender reads its send count before sending, but
// the receiver reads its recv count after receiving. Hence, both > and <
// are valid states.
if (sentCnt != recvCnt) {
return true;
}
}
}
return false;
}
void ProcessGroupAgent::sync() {
// Block until all processes want to sync.
pg_->barrier()->wait();
// block until all peers agree that all sent messages have been processed.
do {
// Finish all send/recv tasks in the thread pool
threadPool_.waitWorkComplete();
// As there could be nested RPC calls, or a response callback could also
// trigger more messages to be sent, we need to wait for the thread pool
// again.
} while (hasPendingMessage());
}
std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
const WorkerInfo& to,
Message&& message) {
TORCH_CHECK(
to.id_ != (worker_id_t)pg_->getRank(),
"ProcessGroupAgent does not support making RPC calls to self.");
TORCH_CHECK(
to.id_ < (worker_id_t)pg_->getSize(),
"Destination rank is out of bounds, got ",
to.id_,
", but world size is ",
pg_->getSize());
auto requestId = nextId();
auto future = std::make_shared<FutureMessage>();
if (message.isRequest()) {
{
std::lock_guard<std::mutex> lock{futureMutex_};
futures_[requestId] = future;
}
message.setId(requestId);
} else {
future->markCompleted();
}
// NB: cannot directly pass ``to`` to the ``SendWork``, because it might no
// longer be alive when the ``SendWork`` is executed. For example, the
// application could query the ``WorkerInfo`` by name through the
// ``RpcAgent::getWorkerInfo`` API and pass the ``WorkerInfo`` back here, so
// we have C++ -> Python -> C++. For an asynchronous RPC, the ``WorkerInfo``
// reference on the Python side could die before the ``SendWork`` uses it, and
// pybind will not keep the Python reference alive even if it originally came
// from the C++ land. Hence, we have to explicitly use the ``WorkerInfo`` in the
// C++ land.
enqueueSend(SendWork(allWorkerInfo_[to.id_], std::move(message)));
return future;
}
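// Serializes the message and sends it on the thread pool. Each message goes
// out as two tensors: a preamble carrying (source rank, payload size, message
// type) and the serialized payload. SHUTDOWN messages send only the preamble.
// Sends to a given destination are serialized by a per-destination mutex
// because concurrent ProcessGroup sends on the same tag are not thread-safe.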
void ProcessGroupAgent::enqueueSend(SendWork work) {
// NB: this can be changed to use a native move capture when moved to C++14
threadPool_.run(std::bind(
[&](const SendWork& work) {
std::stringstream ss;
serialize(work.message_, ss);
std::string serializedPayload = ss.str();
std::vector<torch::Tensor> preamble = {torch::tensor(
{(int64_t)pg_->getRank(),
(int64_t)serializedPayload.length(),
(int64_t)work.message_.type()},
{torch::kLong})};
// ProcessGroup is not thread-safe when sending with the same tag, hence
// the lock
std::vector<std::shared_ptr<c10d::ProcessGroup::Work>> pendingSends;
const auto& dst = work.to_.id_;
if (work.message_.isShutdown()) {
pendingSends.reserve(1);
{
std::lock_guard<std::mutex> guard(sendMutexes_[dst]);
pendingSends.emplace_back(
pg_->send(preamble, dst, dst /* channelTag */));
}
} else {
std::vector<torch::Tensor> payload = {torch::from_blob(
(void*)serializedPayload.c_str(),
serializedPayload.length(),
{torch::kChar})};
pendingSends.reserve(2);
sendCounts_.increment(dst);
{
std::lock_guard<std::mutex> guard(sendMutexes_[dst]);
pendingSends.emplace_back(
pg_->send(preamble, dst, dst /* channelTag */));
pendingSends.emplace_back(
pg_->send(payload, dst, dst /* channelTag */));
}
}
for (auto& pendingSend : pendingSends) {
pendingSend->wait();
}
},
std::move(work)));
}
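// Deserializes an incoming message on the thread pool. Requests are handed to
// the RequestCallback and their responses sent back to the sender; responses
// complete and erase the matching pending future recorded in send().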
void ProcessGroupAgent::enqueueRecv(RecvWork work) {
threadPool_.run(std::bind(
[&](RecvWork& work) {
torch::Tensor& payload = work.payload_;
std::stringstream ss(std::string(
(char*)payload.storage().data<signed char>(), payload.numel()));
Message message = deserialize(work.type_, ss);
if (message.isRequest()) {
send(work.from_, cb_->operator()(message));
} else if (message.isResponse()) {
auto id = message.id();
std::shared_ptr<FutureMessage> fm = nullptr;
{
std::lock_guard<std::mutex> lock{futureMutex_};
fm = futures_[id];
}
// Not holding the lock while calling markCompleted, as it could run
// callbacks that call this agent's send()
fm->markCompleted(std::move(message));
{
std::lock_guard<std::mutex> lock{futureMutex_};
futures_.erase(id);
}
futureCV_.notify_all();
} else {
// TODO: pass the error back to the caller instead of crashing here.
TORCH_INTERNAL_ASSERT(
false, "unrecognized message type ", message.type());
}
recvCounts_.increment(work.from_.id_);
},
std::move(work)));
}
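// Runs on the dedicated listener thread. It blocks on recvAnysource for the
// 3-element preamble (source rank, payload size, message type), then receives
// the payload from that rank and enqueues a RecvWork, until a SHUTDOWN
// message arrives.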
void ProcessGroupAgent::listenLoop() {
while (true) {
// rank, tensor size, message type
std::vector<torch::Tensor> preamble = {torch::empty({3}, {torch::kInt64})};
pg_->recvAnysource(preamble, pg_->getRank())->wait();
int64_t* preamble_items = preamble.front().storage().data<int64_t>();
auto srcRank = preamble_items[0];
auto size = preamble_items[1];
MessageType type = MessageType(preamble_items[2]);
if (type == MessageType::SHUTDOWN) {
// FIXME: This LOG also prints a warning that InitGoogleLogging() was not
// invoked before logging, but it is not appropriate to call
// InitGoogleLogging() here either.
LOG(INFO) << "Shutting down ProcessGroupAgent " << workerInfo_.name_
<< std::endl;
return;
}
std::vector<torch::Tensor> tensors = {torch::empty({size}, {torch::kChar})};
pg_->recv(tensors, srcRank, pg_->getRank())->wait();
enqueueRecv(RecvWork(allWorkerInfo_[srcRank], type, std::move(tensors[0])));
}
}
} // namespace rpc
} // namespace distributed
} // namespace torch