sync and async torch.distributed.rpc for builtin operators (#23228)

Summary:
Features:

* sync and async RPC for builtin operators (see the usage sketch after this list)
* RpcAgent API
* ProcessGroupAgent implementation
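
The snippet below is a usage sketch, not code from this diff: it only illustrates how the sync and async paths differ on top of the `RpcAgent::send` and `FutureMessage` APIs added here. The destination name `"worker1"` and the pre-built request `Message` arguments are placeholders.

```cpp
// Usage sketch only (not part of this diff). Assumes `agent` is a fully
// constructed RpcAgent (e.g. a ProcessGroupAgent) and that the two Messages
// are SCRIPT_CALL requests built via ScriptCall::toMessage().
#include <torch/csrc/distributed/rpc/RpcAgent.h>

using namespace torch::distributed::rpc;

void syncAndAsyncCalls(RpcAgent& agent, Message&& syncReq, Message&& asyncReq) {
  // Synchronous path: block on the returned FutureMessage until the
  // response Message arrives.
  auto syncFut = agent.send("worker1", std::move(syncReq));
  const Message& response = syncFut->wait();
  (void)response; // a real caller would deserialize via ScriptRet::fromMessage()

  // Asynchronous path: register a callback and return immediately. The
  // callback fires when the agent marks the future completed.
  auto asyncFut = agent.send("worker1", std::move(asyncReq));
  asyncFut->addCallback([](const Message& reply) {
    (void)reply; // consume the response here
  });
}
```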

Goal:

* have a minimum working and testable RPC implementation
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementations
  * A TensorPipe implementation might allocate multiple underlying communication channels of different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only converts a BuiltinOp into a Message that contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For ThriftAgent, since Thrift has its own request/response matching solution, the Message.id is no longer necessary and can be dropped during serialization. All it needs to do is pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback does not return before sending out the response
  * non-blocking can be achieved by enqueuing the `(from, request, RpcAgent&)` tuple and using a different thread to process it later; that is why there is an `RpcAgent&` argument in the parameter list (see the sketch after this list)
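
Below is a minimal sketch of the non-blocking pattern described above; it is not part of this diff. It assumes the `RequestCallback` signature and the `processRequestBlocking` helper from this change, omits error handling, and the names `NonBlockingProcessor` and `PendingRequest` are made up for illustration.

```cpp
// Sketch only: a RequestCallback body that merely enqueues the
// (from, request, RpcAgent&) tuple; a worker thread sends the response later.
#include <torch/csrc/distributed/rpc/RpcAgent.h>
#include <torch/csrc/distributed/rpc/functions.h> // processRequestBlocking

#include <condition_variable>
#include <deque>
#include <mutex>
#include <string>
#include <thread>
#include <utility>

using namespace torch::distributed::rpc;

struct PendingRequest {
  std::string from;
  Message request;
  RpcAgent* agent;
};

class NonBlockingProcessor {
 public:
  NonBlockingProcessor() : worker_([this] { loop(); }) {}

  ~NonBlockingProcessor() {
    {
      std::lock_guard<std::mutex> guard(mutex_);
      stop_ = true;
    }
    cv_.notify_one();
    worker_.join();
  }

  // The RequestCallback body: returns as soon as the work is queued, well
  // before the response is sent back through the agent.
  void operator()(std::string from, Message&& request, RpcAgent& agent) {
    {
      std::lock_guard<std::mutex> guard(mutex_);
      queue_.push_back({std::move(from), std::move(request), &agent});
    }
    cv_.notify_one();
  }

 private:
  void loop() {
    for (;;) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [this] { return stop_ || !queue_.empty(); });
      if (queue_.empty()) {
        return; // stop_ was set and all pending work has been drained
      }
      PendingRequest work = std::move(queue_.front());
      queue_.pop_front();
      lock.unlock();
      // Reuse the blocking handler from this diff to run the op and send the
      // response through the RpcAgent reference captured at enqueue time.
      processRequestBlocking(work.from, std::move(work.request), *work.agent);
    }
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  std::deque<PendingRequest> queue_;
  bool stop_ = false;
  std::thread worker_; // declared last so all other members exist when it starts
};
```

Since the class above is not copyable, an agent implementation would hand it to a `RequestCallback` indirectly, e.g. via a small lambda such as `[&processor](std::string from, Message&& m, RpcAgent& a) { processor(std::move(from), std::move(m), a); }`.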

We are not exporting this diff until we finalize the distributed autograd design and publish the API review publicly.

https://fb.quip.com/FabTAZKVgQpf

Pull Request resolved: https://github.com/pytorch/pytorch/pull/23228
ghstack-source-id: 87816717

Reviewed By: zhaojuanmao

Differential Revision: D15194693

fbshipit-source-id: 7adb600796613cde6073db6c227451b89940ecaf
This commit is contained in:
Shen Li
2019-08-06 15:58:52 -07:00
committed by Facebook Github Bot
parent c07fc96b94
commit 8b349073ce
26 changed files with 1420 additions and 4 deletions

View File

@@ -458,13 +458,17 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
)
if (NOT INTERN_BUILD_MOBILE)
list(APPEND TORCH_SRCS
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/FutureMessage.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/Message.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/ScriptCall.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/ScriptRet.cpp
${TORCH_SRC_DIR}/csrc/jit/export.cpp
${TORCH_ROOT}/test/cpp/jit/test.cpp
)
if (NOT WIN32)
list(APPEND TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp)
endif()
endif()

View File

@@ -48,6 +48,7 @@ TESTS = [
'quantized',
'quantized_tensor',
'quantizer',
'rpc',
'sparse',
'torch',
'type_info',

114
test/test_rpc.py Normal file
View File

@@ -0,0 +1,114 @@
import sys

import torch
import torch.distributed as dist

from common_distributed import MultiProcessTestCase
from common_utils import load_tests, run_tests

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if not dist.is_available():
    print('c10d not available, skipping tests')
    sys.exit(0)


def _wrap_with_rpc(func):
    def wrapper(self):
        store = dist.FileStore(self.file.name, self.world_size)
        dist.init_process_group(backend='gloo', rank=self.rank,
                                world_size=self.world_size, store=store)
        dist.init_rpc('worker%d' % self.rank)
        func(self)
        dist.join_rpc()

    return wrapper


class RpcTest(MultiProcessTestCase):

    @property
    def world_size(self):
        return 4

    @_wrap_with_rpc
    def test_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)))
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_scalar_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), n))
        self.assertEqual(ret, (torch.ones(n, n) + n))

    @_wrap_with_rpc
    def test_async_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        fut = dist.rpc('worker%d' % dstRank,
                       torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)),
                       async_call=True)
        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_nonzero(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        x = torch.ones(self.world_size, self.world_size)
        x[self.rank][self.rank] = 0
        ret = dist.rpc('worker%d' % dstRank, torch.nonzero, args=(x,))
        self.assertEqual(ret, x.nonzero())

    @_wrap_with_rpc
    def test_multi_rpc(self):
        dstRank = (self.rank + 1) % self.world_size
        for i in range(20):
            n = i + self.rank + 1
            ret = dist.rpc('worker%d' % dstRank, torch.add,
                           args=(torch.ones(n, n), torch.ones(n, n)))
            self.assertEqual(ret, torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_sync_rpc(self):
        dstRank = (self.rank + 1) % self.world_size
        for i in range(20):
            dist.sync_rpc()
            n = i + self.rank + 1
            ret1 = dist.rpc('worker%d' % dstRank, torch.add,
                            args=(torch.ones(n, n), torch.ones(n, n)))
            dist.sync_rpc()
            ret2 = dist.rpc('worker%d' % dstRank, torch.add,
                            args=(torch.ones(n, n), 2))
            dist.sync_rpc()
            self.assertEqual(ret1, torch.ones(n, n) * 2)
            self.assertEqual(ret2, torch.ones(n, n) * 3)

    @_wrap_with_rpc
    def test_join_rpc(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)))
        self.assertEqual(ret, torch.ones(n, n) * 2)
        dist.join_rpc()

        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
            dist.rpc('worker%d' % dstRank, torch.add,
                     args=(torch.ones(n, n), torch.ones(n, n)))

        # it's safe to call join_rpc() multiple times
        dist.join_rpc()


if __name__ == '__main__':
    run_tests()

View File

@@ -49,6 +49,10 @@ libtorch_sources = [
"torch/csrc/autograd/record_function.cpp",
"torch/csrc/autograd/saved_variable.cpp",
"torch/csrc/autograd/variable.cpp",
"torch/csrc/distributed/rpc/FutureMessage.cpp",
"torch/csrc/distributed/rpc/Message.cpp",
"torch/csrc/distributed/rpc/ScriptCall.cpp",
"torch/csrc/distributed/rpc/ScriptRet.cpp",
"torch/csrc/Exceptions.cpp",
"torch/csrc/jit/autodiff.cpp",
"torch/csrc/jit/attributes.cpp",
@@ -229,6 +233,11 @@ def add_torch_libs():
"torch/csrc/distributed/c10d/comm.cpp",
"torch/csrc/distributed/c10d/init.cpp",
"torch/csrc/distributed/c10d/reducer.cpp",
"torch/csrc/distributed/rpc/init.cpp",
"torch/csrc/distributed/rpc/RpcAgent.cpp",
"torch/csrc/distributed/rpc/ProcessGroupAgent.cpp",
"torch/csrc/distributed/rpc/functions.cpp",
"torch/csrc/distributed/rpc/python_functions.cpp",
"torch/csrc/jit/init.cpp",
"torch/csrc/jit/passes/inline_fork_wait.cpp",
"torch/csrc/jit/passes/onnx.cpp",

View File

@@ -224,6 +224,12 @@ if (USE_DISTRIBUTED)
${TORCH_SRC_DIR}/csrc/distributed/c10d/comm.cpp
${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp
${TORCH_SRC_DIR}/csrc/distributed/c10d/reducer.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/init.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/RpcAgent.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/ProcessGroupAgent.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/functions.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_functions.cpp
)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D)

View File

@@ -55,6 +55,7 @@
#ifdef USE_C10D
#include <torch/csrc/distributed/c10d/c10d.h>
#endif
#include <torch/csrc/distributed/rpc/rpc.h>
#endif
#define WITH_NUMPY_IMPORT_ARRAY
@@ -626,6 +627,7 @@ PyObject* initModule() {
#ifdef USE_C10D
THPUtils_addPyMethodDefs(methods, torch::distributed::c10d::python_functions());
#endif
THPUtils_addPyMethodDefs(methods, torch::distributed::rpc::python_functions());
#endif
#if PY_MAJOR_VERSION == 2

View File

@@ -0,0 +1,62 @@
#include <torch/csrc/distributed/rpc/FutureMessage.h>
namespace torch {
namespace distributed {
namespace rpc {
const Message& FutureMessage::wait() {
std::unique_lock<std::mutex> lock(mutex_);
finished_cv_.wait(lock, [this]{return completed_.load();});
return message_;
}
void FutureMessage::markCompleted(Message message) {
{
std::unique_lock<std::mutex> lock(mutex_);
TORCH_CHECK(!completed());
completed_ = true;
message_ = std::move(message);
fireCallbacks();
}
finished_cv_.notify_all();
}
void FutureMessage::markCompleted() {
markCompleted(Message());
}
const Message& FutureMessage::message() {
std::unique_lock<std::mutex> lock(mutex_);
TORCH_CHECK(completed(), "Cannot retrieve message before completed.");
return message_;
}
bool FutureMessage::completed() const {
return completed_;
}
void FutureMessage::addCallback(const FutureMessage::Callback& callback) {
std::unique_lock<std::mutex> lock(mutex_);
if (completed()) {
lock.unlock();
callback(message_);
return;
}
callbacks.push_back(callback);
}
void FutureMessage::fireCallbacks() {
TORCH_CHECK(completed(), "Firing callbacks on incomplete FutureMessage.");
// There is no need to protect callbacks with the lock.
// Once completed_ is set to true, no one can add new callback to the list.
for (auto& callback : callbacks) {
callback(message_);
}
callbacks.clear();
}
}
}
}

View File

@@ -0,0 +1,43 @@
#pragma once
#include <torch/csrc/distributed/rpc/Message.h>
namespace torch {
namespace distributed {
namespace rpc {
// This class holds a message that will be ready in the future.
//
// TODO: consider using ivalue::Future.
struct TORCH_API FutureMessage final {
public:
using Callback = std::function<void(const Message&)>;
// TODO: add a get() API that returns immediately with an optional Message
// object.
const Message& wait();
void markCompleted(Message message);
void markCompleted();
const Message& message();
bool completed() const;
// If completed() the callback will be invoked in-place.
void addCallback(const Callback& callback);
private:
void fireCallbacks();
std::mutex mutex_;
std::atomic_bool completed_ {false}; // is this future complete
std::condition_variable finished_cv_;
std::vector<Callback> callbacks;
// TODO: make message_ an optional field, and get rid of UNKNOWN message type
Message message_;
};
}
}
}

View File

@@ -0,0 +1,88 @@
#include <torch/csrc/distributed/rpc/Message.h>
namespace torch {
namespace distributed {
namespace rpc {
Message::Message() = default;
Message::Message(
std::vector<char>&& payload,
std::vector<torch::Tensor>&& tensors,
MessageType type)
: payload_(payload), tensors_(tensors), type_(type) {}
Message::Message(
std::vector<char>&& payload,
std::vector<torch::Tensor>&& tensors,
MessageType type,
int64_t id)
: payload_(payload), tensors_(tensors), type_(type), id_(id) {}
Message::Message(const Message& other) = default;
Message::Message(Message&& other) noexcept = default;
Message& Message::operator=(Message const& rhs) & {
auto payload = rhs.payload_;
auto tensors = rhs.tensors_;
Message(std::move(payload),
std::move(tensors),
rhs.type_,
rhs.id_).swap(*this);
return *this;
}
Message& Message::operator=(Message&& rhs) & {
Message(std::move(rhs.payload_),
std::move(rhs.tensors_),
rhs.type_,
rhs.id_).swap(*this);
return *this;
}
void Message::swap(Message& rhs) noexcept {
std::swap(payload_, rhs.payload_);
std::swap(tensors_, rhs.tensors_);
std::swap(type_, rhs.type_);
std::swap(id_, rhs.id_);
}
const std::vector<char>& Message::payload() const {
return payload_;
}
const std::vector<torch::Tensor>& Message::tensors() const {
return tensors_;
}
const MessageType& Message::type() const {
return type_;
}
bool Message::isRequest() const {
return MessageType::SCRIPT_CALL == type_
|| MessageType::PYTHON_CALL == type_;
}
bool Message::isResponse() const {
return MessageType::SCRIPT_RET == type_
|| MessageType::PYTHON_RET == type_;
}
bool Message::isShutdown() const {
return MessageType::SHUTDOWN == type_;
}
int64_t Message::id() const {
return id_;
}
void Message::setId(int64_t id) {
id_ = id;
}
}
}
}

View File

@@ -0,0 +1,81 @@
#pragma once
#include <torch/serialize.h>
#include <vector>
namespace torch {
namespace distributed {
namespace rpc {
enum MessageType {
SCRIPT_CALL = 0,
SCRIPT_RET,
PYTHON_CALL,
PYTHON_RET,
SHUTDOWN,
UNKNOWN
};
// A message to be sent/received by an RpcAgent.
//
// A Message object contains 4 fields:
// payload (std::vector<char>): a binary chunk of data.
// tensors (std::vector<torch::Tensor>): all tensors. Tensor data are not
// included in the payload, and it is up to the RpcAgent implementation
// to determine how to serialize them. This design is helpful for
// communicating super large tensors where serializing all the data at
// once leads to excessively large memory footprint. An implementation
// can then serialize and send tensors chunk by chunk, in a streaming
// fashion.
// type (MessageType): type of the message.
// id (int64_t): message id; this is used by ProcessGroupAgent to match
// requests and responses. Other implementations can ignore it
// if they have their own way of doing the matching.
//
// Layers above ``RpcAgent`` only convert ScriptCall, ScriptRet, PythonCall,
// and PythonRet into a Message, and it is up to the RpcAgent
// implementation to determine how to serialize a message.
class TORCH_API Message final {
public:
Message();
Message(std::vector<char>&& payload,
std::vector<torch::Tensor>&& tensors,
MessageType type);
Message(std::vector<char>&& payload,
std::vector<torch::Tensor>&& tensors,
MessageType type,
int64_t id);
Message(const Message& other);
Message(Message&& other) noexcept;
Message& operator=(Message const& rhs) &;
Message& operator=(Message&& rhs) &;
void swap(Message& rhs) noexcept;
const std::vector<char>& payload() const;
const std::vector<torch::Tensor>& tensors() const;
const MessageType& type() const;
bool isRequest() const;
bool isResponse() const;
bool isShutdown() const;
// id is an optional field to match request/response. If an RpcAgent
// implementation is able to do the matching without using this id, it can be
// dropped during message serialization.
int64_t id() const;
void setId(int64_t id);
private:
std::vector<char> payload_;
std::vector<torch::Tensor> tensors_;
MessageType type_ = MessageType::UNKNOWN;
int64_t id_ = -1;
};
} // rpc
} // distributed
} // torch

View File

@@ -0,0 +1,227 @@
#include <torch/csrc/distributed/rpc/ProcessGroupAgent.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
// Write the message into the given ostream
void serialize(const Message& message, std::ostream& os) {
// We cast const void* to void* here because we need to create a tensor using
// that memory space. It is fine as that tensor stays function-local, and will
// not be modified during its lifetime.
auto payload = const_cast<void*>( // NOLINT
static_cast<const void*>(message.payload().data()));
auto payload_size = message.payload().size();
// getting tensor table from the message
std::vector<torch::Tensor> tensors = message.tensors();
// append payload as a tensor
tensors.push_back(torch::from_blob(payload, payload_size, {torch::kChar}));
// append id and type as a tensor
tensors.push_back(torch::tensor(
{message.id(), (int64_t) message.type()}, {torch::kInt64}
));
torch::save(tensors, os);
}
Message deserialize(std::istream& is) {
std::vector<torch::Tensor> tensors;
torch::load(tensors, is);
TORCH_CHECK(tensors.size() >= 2, "Failed to deserialize a message.");
auto miscTensor = std::move(tensors.back());
tensors.pop_back();
auto payloadTensor = std::move(tensors.back());
tensors.pop_back();
int64_t* miscItems = miscTensor.storage().data<int64_t>();
int64_t id = miscItems[0];
MessageType type = MessageType(miscItems[1]);
std::vector<char> payload(payloadTensor.numel());
if (payloadTensor.numel() > 0) {
std::memcpy(payload.data(),
payloadTensor.storage().data(),
payloadTensor.numel());
}
return Message(std::move(payload), std::move(tensors), type, id);
}
} // namespace
ProcessGroupAgent::ProcessGroupAgent(
std::string workerName,
std::unordered_map<std::string, int> nameMap,
std::shared_ptr<c10d::ProcessGroup> pg)
: RpcAgent(std::move(workerName), processRequestBlocking),
nameMap_(std::move(nameMap)),
stop_(false),
pg_(std::move(pg)),
nextId_(0) {
TORCH_CHECK(nameMap_.size() > 1, "ProcessGroupAgent requires world_size to "
"be at least 2, but got ", nameMap_.size());
auto workerRankIter = nameMap_.find(workerName_);
TORCH_CHECK(workerRankIter != nameMap_.end(), "Failed to resolve worker "
"name ", workerName_, " to a ProcessGroup rank.");
TORCH_CHECK(pg_->getRank() == workerRankIter -> second,
"Resolved worker rank ", workerRankIter -> second,
" does not match ProcessGroup rank ", pg_->getRank());
names_.resize(nameMap_.size());
for (auto& entry : nameMap_) {
names_[entry.second] = entry.first;
}
sendThread_ = std::thread(&ProcessGroupAgent::sendLoop, this);
listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
}
void ProcessGroupAgent::join() {
// Every process i sends a SHUTDOWN message to process i + 1. This is
// necessary for now because:
// 1. There is no abort API for ProcessGroup::recvAnysource yet. We have to
// feed it a message or kill the thread.
// 2. A GLOO process cannot send a message to itself. (there is an ongoing
// effort to fix this problem).
sync();
int dst = (pg_->getRank() + 1) % pg_->getSize();
enqueue(SendWork(dst, Message({}, {}, MessageType::SHUTDOWN)));
std::unique_lock<std::mutex> lock(sendQueueMutex_);
workConsumeCV_.wait(lock, [&] { return sendQueue_.empty(); });
stop_ = true;
lock.unlock();
workProduceCV_.notify_all();
sendThread_.join();
listenerThread_.join();
}
void ProcessGroupAgent::sync() {
// Block until all processes want to sync. This is necessary before acquiring
// the lock below, because other processes might not enter sync() until it
// gets some response from this RpcAgent.
pg_->barrier()->wait();
// Acquire the lock on the send queue to prevent additional messages from being
// put onto the send queue.
std::unique_lock<std::mutex> lock(sendQueueMutex_);
// Wait until the send queue is depleted.
workConsumeCV_.wait(lock, [&] { return sendQueue_.empty(); });
// Use another barrier in case different RpcAgents handle different amounts of
// work.
pg_->barrier()->wait();
}
std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
const std::string& to, Message&& message) {
auto dstRankIter = nameMap_.find(to);
TORCH_CHECK(dstRankIter != nameMap_.end(), "Unknown destination worker ", to);
const int dstRank = dstRankIter -> second;
TORCH_CHECK(dstRank != pg_->getRank(), "ProcessGroupAgent does not support "
"making RPC calls to self.")
auto requestId = nextId();
auto future = std::make_shared<FutureMessage>();
if (message.isRequest()) {
{
std::lock_guard<std::mutex> lock{futureMutex_};
futures_[requestId] = future;
}
message.setId(requestId);
} else {
future->markCompleted();
}
enqueue(SendWork(dstRank, std::move(message)));
return future;
}
void ProcessGroupAgent::enqueue(SendWork work) {
std::unique_lock<std::mutex> lock(sendQueueMutex_);
sendQueue_.emplace_back(std::move(work));
lock.unlock();
workProduceCV_.notify_one();
}
// making sure tensors are not deleted before send finishes
void ProcessGroupAgent::sendLoop() {
std::unique_lock<std::mutex> lock(sendQueueMutex_);
while (!stop_) {
if (sendQueue_.empty()) {
workProduceCV_.wait(lock);
continue;
}
auto work = std::move(sendQueue_.front());
sendQueue_.pop_front();
lock.unlock();
workConsumeCV_.notify_one();
std::stringstream ss;
serialize(work.message_, ss);
std::string str = ss.str();
std::vector<torch::Tensor> preamble = {
torch::tensor(
{
(int64_t)pg_->getRank(),
(int64_t)str.length(),
}, {torch::kLong})
};
pg_->send(preamble, work.dstRank_, work.dstRank_ /* channelTag */)->wait();
std::vector<torch::Tensor> payload =
{torch::from_blob((void *)str.c_str(), str.length(), {torch::kChar})};
pg_->send(payload, work.dstRank_, work.dstRank_ /* channelTag */)->wait();
lock.lock();
}
}
void ProcessGroupAgent::listenLoop() {
while (true) {
// rank, tensor size
std::vector<torch::Tensor> preamble = {torch::empty({2}, {torch::kInt64})};
pg_->recvAnysource(preamble, pg_->getRank())->wait();
int64_t* preamble_items = preamble.front().storage().data<int64_t>();
auto srcRank = preamble_items[0];
auto size = preamble_items[1];
std::vector<torch::Tensor> tensors = {torch::empty({size}, {torch::kChar})};
pg_->recv(tensors, srcRank, pg_->getRank())->wait();
std::stringstream ss(std::string(
(char*)tensors[0].storage().data<signed char>(), tensors[0].numel()));
Message message = deserialize(ss);
if (message.isRequest()) {
cb_(names_[srcRank], std::move(message), *this);
} else if (message.isResponse()) {
auto id = message.id();
{
std::lock_guard<std::mutex> lock{futureMutex_};
futures_[id]->markCompleted(std::move(message));
futures_.erase(id);
}
} else if (message.isShutdown()) {
break;
} else {
AT_ERROR("unrecognized message type ", message.type());
}
}
}
}
}
}

View File

@@ -0,0 +1,75 @@
#pragma once
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/rpc/FutureMessage.h>
#include <torch/csrc/distributed/rpc/RpcAgent.h>
#include <torch/csrc/distributed/rpc/functions.h>
#include <deque>
#include <thread>
namespace torch {
namespace distributed {
namespace rpc {
struct SendWork {
SendWork(const int dstRank,
Message&& message)
: dstRank_(dstRank),
message_(message) {}
const int dstRank_;
Message message_;
};
class ProcessGroupAgent : public RpcAgent {
public:
ProcessGroupAgent(std::string workerName,
std::unordered_map<std::string, int> nameMap,
std::shared_ptr<c10d::ProcessGroup> pg);
// This method wraps the destination information and the message into a
// SendWork object, and puts the SendWork into a queue. Another thread will
// consume SendWork from the queue and send it out.
std::shared_ptr<FutureMessage> send(
const std::string& to, Message&& message) override;
void join() override;
void sync() override;
private:
// put SendWork into a queue and notify the sendLoop thread
void enqueue(SendWork work);
// sending out the message
void sendLoop();
// receiving messages
void listenLoop();
int64_t nextId() {
return nextId_++;
}
// worker name -> rank
std::unordered_map<std::string, int> nameMap_;
bool stop_;
std::shared_ptr<c10d::ProcessGroup> pg_;
std::atomic<int64_t> nextId_;
// names_[rank] stores the name of the corresponding worker; use this vector
// to get worker name from rank and pass it to the RequestCallback.
std::vector<std::string> names_;
std::deque<SendWork> sendQueue_;
std::mutex sendQueueMutex_;
std::condition_variable workProduceCV_;
std::condition_variable workConsumeCV_;
std::thread sendThread_;
std::thread listenerThread_;
std::unordered_map<int64_t, std::shared_ptr<FutureMessage>> futures_;
std::mutex futureMutex_;
};
}
}
}

View File

@@ -0,0 +1,14 @@
#include <torch/csrc/distributed/rpc/RpcAgent.h>
namespace torch {
namespace distributed {
namespace rpc {
RpcAgent::RpcAgent(std::string workerName, RequestCallback cb)
: workerName_(std::move(workerName)), cb_(std::move(cb)) {}
RpcAgent::~RpcAgent() = default;
}
}
}

View File

@@ -0,0 +1,67 @@
#pragma once
#include <torch/csrc/distributed/rpc/FutureMessage.h>
#include <torch/csrc/distributed/rpc/Message.h>
namespace torch {
namespace distributed {
namespace rpc {
// RpcAgent is the base class for sending and receiving RPC messages. It
// provides a unified ``send`` API for both request and response messages, and
// will invoke the given ``RequestCallback`` to process received requests. It
// should immediately become ready to serve requests and accept responses after
// construction.
class RpcAgent;
// An RpcAgent implementation should invoke ``RequestCallback`` to process received
// requests. There is no restriction on the implementation's threading model.
// This function takes the name of the request sender, an rvalue reference
// to the Message object, and a reference to the RpcAgent itself. Having a
// reference to the RpcAgent allows the ``RequestCallback`` implementation to
// be both stateless and non-blocking. It may enqueue the message and the
// RpcAgent reference, and use a different set of threads to process them later.
using RequestCallback = std::function<void(std::string, Message&&, RpcAgent&)>;
class RpcAgent {
public:
// The ``workerName`` is the globally unique name for this RpcAgent. It is up
// to the RpcAgent implementation to determine how to resolve names.
// The ``RequestCallback`` will be invoked to handle received requests. This
// RpcAgent base class makes no assumption on the thread-safeness of the
// ``RequestCallback``. RpcAgent implementations need to make sure that their
// threading model conforms to ``RequestCallback``'s requirements.
RpcAgent(std::string workerName, RequestCallback cb);
virtual ~RpcAgent();
// Send a message to the ``RpcAgent`` of name ``to`` and returns a
// ``FutureMessage`` ptr. The implementation must be asynchronous, i.e., it
// cannot block until it receives the response.
//
// If ``message.isRequest()`` is true, the ``FutureMessage`` will be completed
// when the response arrives. For other message types, the Future should be
// ignored by the caller.
//
// TODO: avoid passing strings all the time, e.g., by using symbols as a
// faster alternative.
virtual std::shared_ptr<FutureMessage> send(
const std::string& to, Message&& message) = 0;
// Call sync and join all internal threads. This method should be called
// before every RPC process exits.
virtual void join() = 0;
// Synchronize this process with other RpcAgent processes. Block until all
// RpcAgents reach this method and send all pending messages.
virtual void sync() = 0;
protected:
const std::string workerName_;
const RequestCallback cb_;
};
}
}
}

View File

@@ -0,0 +1,104 @@
#include <torch/csrc/distributed/rpc/ScriptCall.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
using torch::jit::Pickler;
using torch::jit::Unpickler;
} // namespace
const std::string ScriptCall::BUILTIN_OP_NAMESPACE_("torch.ops.aten.");
const std::string ScriptCall::ATEN_PREFIX_("aten::");
ScriptCall::ScriptCall(
std::shared_ptr<Operator> op, std::vector<at::IValue>&& args)
: op_(std::move(op)), stack_(args) {}
std::shared_ptr<Operator> ScriptCall::op() const {
return *op_;
}
const std::vector<at::IValue>& ScriptCall::stack() const {
return stack_;
}
Message ScriptCall::toMessage() {
std::vector<torch::Tensor> tensor_table;
Pickler pickler(&tensor_table);
pickler.protocol();
pickler.startTuple();
for (auto& value: stack_) {
pickler.pushIValue(value);
}
if (op_) {
// builtin ops
// TODO: replace this with a real overload_name when FunctionSchema supports
// that.
pickler.pushIValue(toString((*op_)->schema()));
// insert qualified name
auto opName = (*op_)->schema().name();
TORCH_CHECK(opName.find("::") == opName.rfind("::") &&
opName.rfind(ATEN_PREFIX_) == 0,
"Unexpected operator name ", opName);
// aten::add -> torch.ops.aten.add
opName.replace(0, ATEN_PREFIX_.length(), BUILTIN_OP_NAMESPACE_);
pickler.pushIValue(opName);
}
pickler.endTuple();
pickler.stop();
auto payload = pickler.stack();
return Message(std::move(payload),
std::move(tensor_table),
MessageType::SCRIPT_CALL);
}
ScriptCall ScriptCall::fromMessage(const Message& message) {
auto payload = static_cast<const void*>(message.payload().data());
auto payload_size = message.payload().size();
Unpickler unpickler(payload, payload_size, &message.tensors(), nullptr);
std::vector<IValue> values = unpickler.parse_ivalue_list();
TORCH_CHECK(values.size() >= 1, "Message of a ScriptCall must at least "
"contain one IValue as the operator schema.");
const std::string& qualifiedName = values.back().toStringRef();
if (qualifiedName.rfind(BUILTIN_OP_NAMESPACE_) == 0) {
values.pop_back();
const std::string& str_schema = values.back().toStringRef();
// extract symbol from the schema
auto str_symbol = str_schema.substr(0, str_schema.find('('));
auto symbol = at::Symbol::fromQualString(str_symbol);
auto op = matchOperator(symbol, str_schema);
// remove str_schema from values
values.pop_back();
return ScriptCall(op, std::move(values));
} else {
AT_ERROR("Unrecognized qualified name ", qualifiedName);
}
}
std::shared_ptr<Operator> ScriptCall::matchOperator(
at::Symbol& symbol, const std::string& str_schema) {
// TODO: This is a temporary solution. We should pass enough information to
// allow deterministic matching to one operator.
for (auto op: torch::jit::getAllOperatorsFor(symbol)) {
if (toString(op->schema()).compare(str_schema) == 0) {
return op;
}
}
AT_ERROR("Cannot find matching operator for schema ", str_schema);
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@@ -0,0 +1,46 @@
#pragma once
#include <c10/util/Optional.h>
#include <torch/csrc/distributed/rpc/Message.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pickler.h>
#include <vector>
namespace torch {
namespace distributed {
namespace rpc {
using torch::jit::Operator;
// A ScriptCall instance represents an invocation of a builtin operator or a
// TorchScript function (not implemented yet). If it is a builtin operator, it
// contains a shared ptr to the `Operator` and a list of arguments.
class TORCH_API ScriptCall final {
public:
ScriptCall(std::shared_ptr<Operator> op, std::vector<at::IValue>&& args);
std::shared_ptr<Operator> op() const;
// return the argument stack of this builtin operator
const std::vector<at::IValue>& stack() const;
Message toMessage();
static ScriptCall fromMessage(const Message& message);
private:
// Given an operator symbol and a string schema, return the matched operator.
static std::shared_ptr<Operator> matchOperator(
at::Symbol& symbol, const std::string& str_schema);
static const std::string BUILTIN_OP_NAMESPACE_;
static const std::string ATEN_PREFIX_;
// This field has value if this ScriptCall represents invocation of a builtin
// operator.
c10::optional<std::shared_ptr<Operator>> op_;
const std::vector<at::IValue> stack_;
};
}
}
}

View File

@@ -0,0 +1,50 @@
#include <torch/csrc/distributed/rpc/ScriptRet.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
using torch::jit::Pickler;
using torch::jit::Unpickler;
} // namespace
ScriptRet::ScriptRet(at::IValue&& value) : value_(value) {}
const at::IValue& ScriptRet::value() {
return value_;
}
Message ScriptRet::toMessage() {
std::vector<torch::Tensor> tensor_table;
Pickler pickler(&tensor_table);
pickler.protocol();
pickler.startTuple();
pickler.pushIValue(value_);
pickler.endTuple();
pickler.stop();
auto payload = pickler.stack();
return Message(std::move(payload),
std::move(tensor_table),
MessageType::SCRIPT_RET);
}
ScriptRet ScriptRet::fromMessage(const Message& message) {
auto payload = static_cast<const void*>(message.payload().data());
auto payload_size = message.payload().size();
Unpickler unpickler(payload, payload_size, &message.tensors(), nullptr);
auto values = unpickler.parse_ivalue_list();
AT_ASSERT(values.size() == 1, "Return value of a builtin operator or a "
"TorchScript function should be a single IValue, got a vector of size ",
values.size());
return ScriptRet(std::move(values.front()));
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@@ -0,0 +1,25 @@
#pragma once
#include <torch/csrc/distributed/rpc/Message.h>
#include <torch/csrc/jit/pickler.h>
namespace torch {
namespace distributed {
namespace rpc {
// Return value of a builtin operator or a TorchScript function.
class TORCH_API ScriptRet final {
public:
explicit ScriptRet(at::IValue&& values);
const at::IValue& value();
Message toMessage();
static ScriptRet fromMessage(const Message& message);
private:
const at::IValue value_;
};
}
}
}

View File

@@ -0,0 +1,32 @@
#include <torch/csrc/distributed/rpc/functions.h>
namespace torch {
namespace distributed {
namespace rpc {
void processRequestBlocking(
const std::string& from, Message&& request, RpcAgent& agent) {
switch (request.type()) {
case MessageType::SCRIPT_CALL: {
ScriptCall op = ScriptCall::fromMessage(request);
auto stack = op.stack();
op.op()->getOperation()(stack);
AT_ASSERT(stack.size() == 1, "Return value of a builtin operator or a "
"TorchScript function should be a single IValue, got a vector of "
"size ", stack.size());
auto response = ScriptRet(std::move(stack.front())).toMessage();
response.setId(request.id());
agent.send(from, std::move(response));
break;
}
default: {
AT_ERROR("Request type ", request.type(), " not supported.");
}
}
}
}
}
}

View File

@@ -0,0 +1,18 @@
#pragma once
#include <torch/csrc/distributed/rpc/FutureMessage.h>
#include <torch/csrc/distributed/rpc/Message.h>
#include <torch/csrc/distributed/rpc/RpcAgent.h>
#include <torch/csrc/distributed/rpc/ScriptCall.h>
#include <torch/csrc/distributed/rpc/ScriptRet.h>
namespace torch {
namespace distributed {
namespace rpc {
void processRequestBlocking(
const std::string& from, Message&& message, RpcAgent& agent);
} // rpc
} // distributed
} // torch

View File

@@ -0,0 +1,83 @@
#include <torch/csrc/python_headers.h>
#include <torch/csrc/distributed/rpc/FutureMessage.h>
#include <torch/csrc/distributed/rpc/ProcessGroupAgent.h>
#include <torch/csrc/distributed/rpc/RpcAgent.h>
#include <torch/csrc/distributed/rpc/functions.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
PyObject* rpc_init(PyObject* /* unused */) {
auto dist_module = THPObjectPtr(PyImport_ImportModule("torch.distributed"));
if (!dist_module) {
throw python_error();
}
auto module = py::handle(dist_module).cast<py::module>();
auto rpcAgent = shared_ptr_class_<RpcAgent>(module, "RpcAgent")
.def("join",
&RpcAgent::join,
py::call_guard<py::gil_scoped_release>())
.def("sync",
&RpcAgent::sync,
py::call_guard<py::gil_scoped_release>());
auto futureMessage = shared_ptr_class_<FutureMessage>(module, "FutureMessage")
.def("wait",
[&](FutureMessage& fut) {
return to_py_obj(fut.wait());
},
py::call_guard<py::gil_scoped_release>());
auto processGroupAgent =
shared_ptr_class_<ProcessGroupAgent>(
module, "ProcessGroupAgent", rpcAgent)
.def(py::init<std::string,
std::unordered_map<std::string, int>,
std::shared_ptr<::c10d::ProcessGroup>>())
.def("join",
&ProcessGroupAgent::join,
py::call_guard<py::gil_scoped_release>())
.def("sync",
&ProcessGroupAgent::sync,
py::call_guard<py::gil_scoped_release>());
module.def("invoke_rpc", [](
RpcAgent& agent,
const std::string& dstName,
const std::string& opName,
const py::args& args,
const py::kwargs& kwargs) {
return py_rpc(agent, dstName, opName, args, kwargs);
});
Py_RETURN_TRUE;
}
} // namespace
static PyMethodDef methods[] = { // NOLINT
{"_rpc_init", (PyCFunction)rpc_init, METH_NOARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};
PyMethodDef* python_functions() {
return methods;
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@@ -0,0 +1,49 @@
#include <torch/csrc/distributed/rpc/python_functions.h>
namespace torch {
namespace distributed {
namespace rpc {
py::object to_py_obj(const Message& message) {
switch (message.type()) {
case MessageType::SCRIPT_RET: {
ScriptRet ret = ScriptRet::fromMessage(message);
Stack stack;
stack.push_back(ret.value());
return torch::jit::createPyObjectForStack(std::move(stack));
}
default: {
AT_ERROR("Unrecognized response message type ", message.type());
}
}
}
std::shared_ptr<FutureMessage> py_rpc(
RpcAgent& agent,
const std::string& dstName,
const std::string& opName,
const py::args& args,
const py::kwargs& kwargs) {
if (opName.rfind("aten", 0) == 0) {
// builtin operators.
Symbol symbol = Symbol::fromQualString(opName);
for (const auto& op: torch::jit::getAllOperatorsFor(symbol)) {
try {
// FIXME: This is a temporary solution. We should at least refactor
// ``createStackForSchema`` to avoid throwing an error.
Stack stack = torch::jit::createStackForSchema(
op->schema(), args, kwargs, c10::nullopt);
return agent.send(
dstName, ScriptCall(op, std::move(stack)).toMessage());
} catch (std::runtime_error) {}
}
}
AT_ERROR("Failed to match operator name ", opName, " and arguments "
"(args: ", args, ", kwargs: ", kwargs, ") to a builtin operator");
}
}
}
}

View File

@@ -0,0 +1,28 @@
#pragma once
#include <torch/csrc/distributed/rpc/FutureMessage.h>
#include <torch/csrc/distributed/rpc/Message.h>
#include <torch/csrc/distributed/rpc/RpcAgent.h>
#include <torch/csrc/distributed/rpc/ScriptCall.h>
#include <torch/csrc/distributed/rpc/ScriptRet.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace distributed {
namespace rpc {
py::object to_py_obj(const Message& message);
std::shared_ptr<FutureMessage> py_rpc(
RpcAgent& agent,
const std::string& dstName,
const std::string& opName,
const py::args& args,
const py::kwargs& kwargs);
}
}
}

View File

@@ -0,0 +1,13 @@
#pragma once
#include <torch/csrc/python_headers.h>
namespace torch {
namespace distributed {
namespace rpc {
PyMethodDef* python_functions();
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@@ -2,10 +2,10 @@ import torch
def is_available():
return hasattr(torch._C, "_c10d_init")
return hasattr(torch._C, "_c10d_init") and hasattr(torch._C, "_rpc_init")
if is_available() and not torch._C._c10d_init():
if is_available() and not (torch._C._c10d_init() and torch._C._rpc_init()):
raise RuntimeError("Failed to initialize PyTorch distributed support")
@@ -15,3 +15,4 @@ if is_available():
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
# this.
from .distributed_c10d import _backend # noqa: F401
from .rpc import * # noqa: F401

174
torch/distributed/rpc.py Normal file
View File

@@ -0,0 +1,174 @@
from . import invoke_rpc
from . import ProcessGroupAgent

import array
import sys
import torch

_agent = None


def _collect_worker_names(name, group):
    from . import all_gather
    from . import get_world_size

    # collect name length
    ws = get_world_size(group)
    name_bytes = name if sys.version_info < (3, 0) else bytes(name, 'utf8')
    name_bytes = list(array.array('B', name_bytes))
    name_len = len(name_bytes)
    len_input = torch.ones(1, dtype=torch.int64) * name_len
    len_outputs = [torch.empty(1, dtype=torch.int64) for _ in range(ws)]
    all_gather(len_outputs, len_input, group=group)

    # collect name value
    max_len = torch.stack(len_outputs).max().item()
    name_input = torch.empty(max_len, dtype=torch.uint8)
    name_input[:name_len] = torch.tensor(name_bytes, dtype=torch.uint8)
    name_outputs = [torch.empty(max_len, dtype=torch.uint8) for _ in range(ws)]
    all_gather(name_outputs, name_input, group=group)
    names = []
    for i in range(ws):
        name_tensor = name_outputs[i][:len_outputs[i]]
        names.append(bytearray(name_tensor.tolist()).decode('utf8'))
    return names


def join_rpc():
    r"""
    Block until all local and remote RPC processes reach this method, process
    (send and receive) all pending messages, and then destroy the local RPC
    agent. Every RPC process must call this method before exiting.
    """
    global _agent

    if _agent:
        _agent.join()
        _agent = None


def sync_rpc():
    r"""
    Block until all local and remote RPC processes reach this method and finish
    sending all pending RPCs. As this method synchronizes at the process
    level, if multiple threads are spawned, only one of them should call this
    method at a time.
    """
    if _agent is None:
        raise RuntimeError("RPC has not been initialized. "
                           "Call init_rpc(name) first.")
    _agent.sync()


# TODO: add a context manager to wrap init_rpc and join_rpc
def init_rpc(name, backend='pg'):
    r"""
    Initialize the local RPC agent, which immediately makes the current process
    ready to send and receive RPCs. The caller needs to make sure the specified
    backend is properly initialized before calling this method. For example, to
    use the ``pg`` (ProcessGroup) backend, ``init_process_group`` must be
    invoked prior to this method.

    Arguments:
        name (str): a globally unique name of the local RPC agent. (e.g.,
            ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
        backend (str): type of RPC backend implementation. Currently,
            process group backend ``"pg"`` is the only available
            backend implementation. (default: ``"pg"``).
    """
    global _agent

    if _agent:
        raise RuntimeError("RPC is already initialized")

    if backend == 'pg':
        from .distributed_c10d import _get_default_group

        group = _get_default_group()
        # TODO: issue #23232
        names = _collect_worker_names(name, group)
        name_dict = {names[r]: r for r in range(len(names))}
        _agent = ProcessGroupAgent(name, name_dict, group)
    else:
        raise RuntimeError("Unrecognized RPC backend ", backend)


def rpc(to, func, args=None, kwargs=None, async_call=False):
    r"""
    Make an RPC call to run function ``func`` on worker ``to``. By default, it
    blocks until the return value is locally available. RPC messages are sent
    and received in parallel to execution of Python code. This method is
    thread-safe.

    Arguments:
        to (str): name of the destination worker.
        func (callable): a builtin function (e.g., ``torch.add``).
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
            invocation.
        async_call (bool): If set to ``True``, this will be an asynchronous RPC
            that returns a ``torch.distributed.FutureMessage``
            object immediately. Otherwise, this RPC will block
            until the return value is locally available.
            (default: ``False``)

    Returns:
        If ``async_call`` is ``False``, returns the result of running ``func``
        on ``args`` and ``kwargs``. If ``async_call`` is ``True``, returns a
        ``torch.distributed.FutureMessage`` object that can be waited on. When
        completed, the return value of ``func`` on ``args`` and ``kwargs`` can
        be retrieved from the ``FutureMessage`` object.

    Example::
        Synchronous example:

        On worker 0:

        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_rpc("worker0")
        >>> ret = dist.rpc("worker1", torch.add, args=(torch.ones(2), 3))
        >>> dist.join_rpc()

        On worker 1:

        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_rpc("worker1")
        >>> dist.join_rpc()

        Asynchronous example:

        On worker 0:

        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_rpc("worker0")
        >>> fut1 = dist.rpc("worker1", torch.add, args=(torch.ones(2), 3), async_call=True)
        >>> fut2 = dist.rpc("worker1", torch.add, args=(torch.ones(2), 2), async_call=True)
        >>> result = fut1.wait() + fut2.wait()
        >>> dist.join_rpc()

        On worker 1:

        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_rpc("worker1")
        >>> dist.join_rpc()
    """
    if _agent is None:
        raise RuntimeError("RPC has not been initialized. "
                           "Call init_rpc(name) first.")

    qualified_name = torch.jit._find_builtin(func)

    if qualified_name is None:
        raise RuntimeError("unknown builtin function %s." % func)

    args = args if args else ()
    kwargs = kwargs if kwargs else {}
    fut = invoke_rpc(_agent, to, qualified_name, *args, **kwargs)

    if async_call:
        return fut
    else:
        return fut.wait()