Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-11 22:34:53 +08:00
sync and async torch.distributed.rpc for builtin operators (#23228)
Summary:

Features:
* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goals:
* have a minimum working and testable RPC implementation
* make sure the RpcAgent API is sufficient for the future ThriftAgent and TensorPipeAgent implementations
  * For the TensorPipe implementation, it might allocate multiple underlying communication channels with different types, and might also use streaming serialization/deserialization for large tensors. To support this requirement, the current implementation only converts a BuiltinOp into a Message, which contains a byte vector and a tensor table. It is up to the RpcAgent implementation to determine how it would like to serialize a Message object.
  * For the ThriftAgent, as Thrift has its own request/response matching solution, the Message.id is no longer necessary. Hence the id can be dropped during serialization. All it needs to do is pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * blocking means the callback won't return before sending out the response
  * non-blocking can be achieved by enqueuing the `(from, request, RpcAgent&)` tuple and using a different thread to process them. That is why there is an `RpcAgent&` arg in the param list.

We are not exporting this diff until we finalize the distributed autograd design and publish the API review publicly. https://fb.quip.com/FabTAZKVgQpf

Pull Request resolved: https://github.com/pytorch/pytorch/pull/23228
ghstack-source-id: 87816717
Reviewed By: zhaojuanmao
Differential Revision: D15194693
fbshipit-source-id: 7adb600796613cde6073db6c227451b89940ecaf
Committed by: Facebook Github Bot
Parent: c07fc96b94
Commit: 8b349073ce
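To make the new surface area concrete, here is a minimal end-to-end sketch of the Python API this PR adds, following the usage exercised in torch/distributed/rpc.py and test/test_rpc.py below (the world size, worker names, and env:// rendezvous variables are illustrative assumptions):

# run the same script in two processes, with RANK=0 and RANK=1
# (MASTER_ADDR/MASTER_PORT must be set for the default rendezvous)
import os
import torch
import torch.distributed as dist

rank = int(os.environ['RANK'])
dist.init_process_group(backend='gloo', rank=rank, world_size=2)
dist.init_rpc('worker%d' % rank)  # RPC rides on the default process group

if rank == 0:
    # synchronous call: blocks until the result is locally available
    ret = dist.rpc('worker1', torch.add, args=(torch.ones(2), 3))
    # asynchronous call: returns a FutureMessage immediately
    fut = dist.rpc('worker1', torch.add, args=(torch.ones(2), 2),
                   async_call=True)
    print(ret, fut.wait())

dist.join_rpc()  # every RPC process must call this before exiting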
@@ -458,13 +458,17 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
   )

   if (NOT INTERN_BUILD_MOBILE)
     list(APPEND TORCH_SRCS
       ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/rpc/FutureMessage.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/rpc/Message.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/rpc/ScriptCall.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/rpc/ScriptRet.cpp
       ${TORCH_SRC_DIR}/csrc/jit/export.cpp
       ${TORCH_ROOT}/test/cpp/jit/test.cpp
     )
     if (NOT WIN32)
       list(APPEND TORCH_SRCS
         ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp)
     endif()
   endif()
@@ -48,6 +48,7 @@ TESTS = [
     'quantized',
     'quantized_tensor',
     'quantizer',
+    'rpc',
     'sparse',
     'torch',
     'type_info',
test/test_rpc.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import sys

import torch
import torch.distributed as dist

from common_distributed import MultiProcessTestCase
from common_utils import load_tests, run_tests


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests


if not dist.is_available():
    print('c10d not available, skipping tests')
    sys.exit(0)


def _wrap_with_rpc(func):
    def wrapper(self):
        store = dist.FileStore(self.file.name, self.world_size)
        dist.init_process_group(backend='gloo', rank=self.rank,
                                world_size=self.world_size, store=store)
        dist.init_rpc('worker%d' % self.rank)
        func(self)
        dist.join_rpc()

    return wrapper


class RpcTest(MultiProcessTestCase):

    @property
    def world_size(self):
        return 4

    @_wrap_with_rpc
    def test_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)))
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_scalar_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), n))
        self.assertEqual(ret, (torch.ones(n, n) + n))

    @_wrap_with_rpc
    def test_async_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        fut = dist.rpc('worker%d' % dstRank,
                       torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)),
                       async_call=True)
        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_nonzero(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        x = torch.ones(self.world_size, self.world_size)
        x[self.rank][self.rank] = 0
        ret = dist.rpc('worker%d' % dstRank, torch.nonzero, args=(x,))
        self.assertEqual(ret, x.nonzero())

    @_wrap_with_rpc
    def test_multi_rpc(self):
        dstRank = (self.rank + 1) % self.world_size
        for i in range(20):
            n = i + self.rank + 1
            ret = dist.rpc('worker%d' % dstRank, torch.add,
                           args=(torch.ones(n, n), torch.ones(n, n)))
            self.assertEqual(ret, torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_sync_rpc(self):
        dstRank = (self.rank + 1) % self.world_size
        for i in range(20):
            dist.sync_rpc()
            n = i + self.rank + 1
            ret1 = dist.rpc('worker%d' % dstRank, torch.add,
                            args=(torch.ones(n, n), torch.ones(n, n)))
            dist.sync_rpc()
            ret2 = dist.rpc('worker%d' % dstRank, torch.add,
                            args=(torch.ones(n, n), 2))
            dist.sync_rpc()
            self.assertEqual(ret1, torch.ones(n, n) * 2)
            self.assertEqual(ret2, torch.ones(n, n) * 3)

    @_wrap_with_rpc
    def test_join_rpc(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)))
        self.assertEqual(ret, torch.ones(n, n) * 2)
        dist.join_rpc()

        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
            dist.rpc('worker%d' % dstRank, torch.add,
                     args=(torch.ones(n, n), torch.ones(n, n)))

        # it's safe to call join_rpc() multiple times
        dist.join_rpc()


if __name__ == '__main__':
    run_tests()
@@ -49,6 +49,10 @@ libtorch_sources = [
     "torch/csrc/autograd/record_function.cpp",
     "torch/csrc/autograd/saved_variable.cpp",
     "torch/csrc/autograd/variable.cpp",
+    "torch/csrc/distributed/rpc/FutureMessage.cpp",
+    "torch/csrc/distributed/rpc/Message.cpp",
+    "torch/csrc/distributed/rpc/ScriptCall.cpp",
+    "torch/csrc/distributed/rpc/ScriptRet.cpp",
     "torch/csrc/Exceptions.cpp",
     "torch/csrc/jit/autodiff.cpp",
     "torch/csrc/jit/attributes.cpp",
@@ -229,6 +233,11 @@ def add_torch_libs():
     "torch/csrc/distributed/c10d/comm.cpp",
     "torch/csrc/distributed/c10d/init.cpp",
     "torch/csrc/distributed/c10d/reducer.cpp",
+    "torch/csrc/distributed/rpc/init.cpp",
+    "torch/csrc/distributed/rpc/RpcAgent.cpp",
+    "torch/csrc/distributed/rpc/ProcessGroupAgent.cpp",
+    "torch/csrc/distributed/rpc/functions.cpp",
+    "torch/csrc/distributed/rpc/python_functions.cpp",
     "torch/csrc/jit/init.cpp",
     "torch/csrc/jit/passes/inline_fork_wait.cpp",
     "torch/csrc/jit/passes/onnx.cpp",
@@ -224,6 +224,12 @@ if (USE_DISTRIBUTED)
     ${TORCH_SRC_DIR}/csrc/distributed/c10d/comm.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/c10d/reducer.cpp
+
+    ${TORCH_SRC_DIR}/csrc/distributed/rpc/init.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/rpc/RpcAgent.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/rpc/ProcessGroupAgent.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/rpc/functions.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_functions.cpp
     )
     list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d)
     list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D)
@@ -55,6 +55,7 @@
 #ifdef USE_C10D
 #include <torch/csrc/distributed/c10d/c10d.h>
 #endif
+#include <torch/csrc/distributed/rpc/rpc.h>
 #endif

 #define WITH_NUMPY_IMPORT_ARRAY
@@ -626,6 +627,7 @@ PyObject* initModule() {
 #ifdef USE_C10D
   THPUtils_addPyMethodDefs(methods, torch::distributed::c10d::python_functions());
 #endif
+  THPUtils_addPyMethodDefs(methods, torch::distributed::rpc::python_functions());
 #endif

 #if PY_MAJOR_VERSION == 2
torch/csrc/distributed/rpc/FutureMessage.cpp (new file, 62 lines)
@@ -0,0 +1,62 @@
#include <torch/csrc/distributed/rpc/FutureMessage.h>

namespace torch {
namespace distributed {
namespace rpc {

const Message& FutureMessage::wait() {
  std::unique_lock<std::mutex> lock(mutex_);
  finished_cv_.wait(lock, [this]{return completed_.load();});
  return message_;
}

void FutureMessage::markCompleted(Message message) {
  {
    std::unique_lock<std::mutex> lock(mutex_);
    TORCH_CHECK(!completed());
    completed_ = true;
    message_ = std::move(message);

    fireCallbacks();
  }
  finished_cv_.notify_all();
}

void FutureMessage::markCompleted() {
  markCompleted(Message());
}

const Message& FutureMessage::message() {
  std::unique_lock<std::mutex> lock(mutex_);
  TORCH_CHECK(completed(), "Cannot retrieve message before completed.");

  return message_;
}

bool FutureMessage::completed() const {
  return completed_;
}

void FutureMessage::addCallback(const FutureMessage::Callback& callback) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (completed()) {
    lock.unlock();
    callback(message_);
    return;
  }
  callbacks.push_back(callback);
}

void FutureMessage::fireCallbacks() {
  TORCH_CHECK(completed(), "Firing callbacks on incomplete FutureMessage.");
  // There is no need to protect callbacks with the lock.
  // Once completed_ is set to true, no one can add new callbacks to the list.
  for (auto& callback : callbacks) {
    callback(message_);
  }
  callbacks.clear();
}

}
}
}
torch/csrc/distributed/rpc/FutureMessage.h (new file, 43 lines)
@@ -0,0 +1,43 @@
#pragma once

#include <torch/csrc/distributed/rpc/Message.h>

namespace torch {
namespace distributed {
namespace rpc {


// This class holds a message that will be ready in the future.
//
// TODO: consider using ivalue::Future.
struct TORCH_API FutureMessage final {

 public:
  using Callback = std::function<void(const Message&)>;

  // TODO: add a get() API that returns immediately with an optional Message
  // object.
  const Message& wait();
  void markCompleted(Message message);
  void markCompleted();
  const Message& message();
  bool completed() const;

  // If completed() the callback will be invoked in-place.
  void addCallback(const Callback& callback);

 private:

  void fireCallbacks();

  std::mutex mutex_;
  std::atomic_bool completed_ {false}; // is this future complete
  std::condition_variable finished_cv_;
  std::vector<Callback> callbacks;
  // TODO: make message_ an optional field, and get rid of UNKNOWN message type
  Message message_;
};

}
}
}
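For readers who prefer Python, the behavior of FutureMessage above can be summarized by a small illustrative analogue; ``MiniFuture`` is a hypothetical name used only for this sketch, not an API in this PR. Note in particular the addCallback contract: a callback added after completion is invoked in place.

import threading

class MiniFuture:
    """Illustrative Python analogue of FutureMessage (not part of this PR)."""

    def __init__(self):
        self._cond = threading.Condition()
        self._completed = False
        self._message = None
        self._callbacks = []

    def mark_completed(self, message):
        with self._cond:
            assert not self._completed, "a future can only be completed once"
            self._completed = True
            self._message = message
            for cb in self._callbacks:  # mirrors fireCallbacks()
                cb(message)
            self._callbacks.clear()
            self._cond.notify_all()

    def add_callback(self, cb):
        with self._cond:
            if not self._completed:
                self._callbacks.append(cb)
                return
        cb(self._message)  # already completed: invoke in place

    def wait(self):
        with self._cond:
            self._cond.wait_for(lambda: self._completed)
            return self._message

fut = MiniFuture()
fut.add_callback(lambda msg: print('got:', msg))
fut.mark_completed('hello')   # fires the callback
assert fut.wait() == 'hello'  # returns immediately once completed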
torch/csrc/distributed/rpc/Message.cpp (new file, 88 lines)
@@ -0,0 +1,88 @@
#include <torch/csrc/distributed/rpc/Message.h>

namespace torch {
namespace distributed {
namespace rpc {

Message::Message() = default;

Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type)
    : payload_(payload), tensors_(tensors), type_(type) {}

Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type,
    int64_t id)
    : payload_(payload), tensors_(tensors), type_(type), id_(id) {}

Message::Message(const Message& other) = default;

Message::Message(Message&& other) noexcept = default;

Message& Message::operator=(Message const& rhs) & {
  auto payload = rhs.payload_;
  auto tensors = rhs.tensors_;
  Message(std::move(payload),
          std::move(tensors),
          rhs.type_,
          rhs.id_).swap(*this);
  return *this;
}

Message& Message::operator=(Message&& rhs) & {
  Message(std::move(rhs.payload_),
          std::move(rhs.tensors_),
          rhs.type_,
          rhs.id_).swap(*this);
  return *this;
}

void Message::swap(Message& rhs) noexcept {
  std::swap(payload_, rhs.payload_);
  std::swap(tensors_, rhs.tensors_);
  std::swap(type_, rhs.type_);
  std::swap(id_, rhs.id_);
}

const std::vector<char>& Message::payload() const {
  return payload_;
}

const std::vector<torch::Tensor>& Message::tensors() const {
  return tensors_;
}

const MessageType& Message::type() const {
  return type_;
}

bool Message::isRequest() const {
  return MessageType::SCRIPT_CALL == type_
      || MessageType::PYTHON_CALL == type_;
}

bool Message::isResponse() const {
  return MessageType::SCRIPT_RET == type_
      || MessageType::PYTHON_RET == type_;
}

bool Message::isShutdown() const {
  return MessageType::SHUTDOWN == type_;
}

int64_t Message::id() const {
  return id_;
}

void Message::setId(int64_t id) {
  id_ = id;
}

}
}
}
torch/csrc/distributed/rpc/Message.h (new file, 81 lines)
@@ -0,0 +1,81 @@
#pragma once

#include <torch/serialize.h>
#include <vector>

namespace torch {
namespace distributed {
namespace rpc {

enum MessageType {
  SCRIPT_CALL = 0,
  SCRIPT_RET,
  PYTHON_CALL,
  PYTHON_RET,
  SHUTDOWN,
  UNKNOWN
};

// A message to be sent/received by an RpcAgent.
//
// A Message object contains 4 fields:
//    payload (std::vector<char>): a binary chunk of data.
//    tensors (std::vector<torch::Tensor>): all tensors. Tensor data are not
//        included in the payload, and it is up to the RpcAgent implementation
//        to determine how to serialize them. This design is helpful for
//        communicating super large tensors where serializing all the data at
//        once leads to excessively large memory footprint. An implementation
//        can then serialize and send tensors chunk-by-chunk, in a streaming
//        fashion.
//    type (MessageType): type of the message.
//    id (int64_t): message id, this is used by ProcessGroupAgent to match
//                  request and response. Other implementations can ignore it
//                  if they have their own ways to do matching.
//
// Layers above ``RpcAgent`` only convert ScriptCall, ScriptRet, PythonCall,
// and PythonRet into a Message, and it is up to the RpcAgent
// implementation to determine how to serialize a message.
class TORCH_API Message final {
 public:

  Message();

  Message(std::vector<char>&& payload,
          std::vector<torch::Tensor>&& tensors,
          MessageType type);

  Message(std::vector<char>&& payload,
          std::vector<torch::Tensor>&& tensors,
          MessageType type,
          int64_t id);

  Message(const Message& other);
  Message(Message&& other) noexcept;
  Message& operator=(Message const& rhs) &;
  Message& operator=(Message&& rhs) &;
  void swap(Message& rhs) noexcept;

  const std::vector<char>& payload() const;
  const std::vector<torch::Tensor>& tensors() const;
  const MessageType& type() const;

  bool isRequest() const;
  bool isResponse() const;
  bool isShutdown() const;

  // id is an optional field to match request/response. If an RpcAgent
  // implementation is able to do the matching without using this id, it can be
  // dropped during message serialization.
  int64_t id() const;
  void setId(int64_t id);

 private:
  std::vector<char> payload_;
  std::vector<torch::Tensor> tensors_;
  MessageType type_ = MessageType::UNKNOWN;
  int64_t id_ = -1;
};

} // rpc
} // distributed
} // torch
torch/csrc/distributed/rpc/ProcessGroupAgent.cpp (new file, 227 lines)
@@ -0,0 +1,227 @@
#include <torch/csrc/distributed/rpc/ProcessGroupAgent.h>

namespace torch {
namespace distributed {
namespace rpc {

namespace {

// Write the message into the given ostream
void serialize(const Message& message, std::ostream& os) {
  // We cast const void* to void* here because we need to create a tensor using
  // that memory space. It is fine as that tensor stays function-local, and will
  // not be modified during its lifetime.
  auto payload = const_cast<void*>( // NOLINT
      static_cast<const void*>(message.payload().data()));
  auto payload_size = message.payload().size();

  // getting tensor table from the message
  std::vector<torch::Tensor> tensors = message.tensors();
  // append payload as a tensor
  tensors.push_back(torch::from_blob(payload, payload_size, {torch::kChar}));
  // append id and type as a tensor
  tensors.push_back(torch::tensor(
      {message.id(), (int64_t) message.type()}, {torch::kInt64}
  ));

  torch::save(tensors, os);
}

Message deserialize(std::istream& is) {
  std::vector<torch::Tensor> tensors;

  torch::load(tensors, is);

  TORCH_CHECK(tensors.size() >= 2, "Failed to deserialize a message.");
  auto miscTensor = std::move(tensors.back());
  tensors.pop_back();
  auto payloadTensor = std::move(tensors.back());
  tensors.pop_back();

  int64_t* miscItems = miscTensor.storage().data<int64_t>();
  int64_t id = miscItems[0];
  MessageType type = MessageType(miscItems[1]);

  std::vector<char> payload(payloadTensor.numel());

  if (payloadTensor.numel() > 0) {
    std::memcpy(payload.data(),
                payloadTensor.storage().data(),
                payloadTensor.numel());
  }

  return Message(std::move(payload), std::move(tensors), type, id);
}

} // namespace

ProcessGroupAgent::ProcessGroupAgent(
    std::string workerName,
    std::unordered_map<std::string, int> nameMap,
    std::shared_ptr<c10d::ProcessGroup> pg)
    : RpcAgent(std::move(workerName), processRequestBlocking),
      nameMap_(std::move(nameMap)),
      stop_(false),
      pg_(std::move(pg)),
      nextId_(0) {
  TORCH_CHECK(nameMap_.size() > 1, "ProcessGroupAgent requires world_size to "
      "be at least 2, but got ", nameMap_.size());
  auto workerRankIter = nameMap_.find(workerName_);
  TORCH_CHECK(workerRankIter != nameMap_.end(), "Failed to resolve worker "
      "name ", workerName_, " to a ProcessGroup rank.");
  TORCH_CHECK(pg_->getRank() == workerRankIter->second,
      "Resolved worker rank ", workerRankIter->second,
      " does not match ProcessGroup rank ", pg_->getRank());

  names_.resize(nameMap_.size());
  for (auto& entry : nameMap_) {
    names_[entry.second] = entry.first;
  }
  sendThread_ = std::thread(&ProcessGroupAgent::sendLoop, this);
  listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
}

void ProcessGroupAgent::join() {
  // Every process i sends a SHUTDOWN message to process i + 1. This is
  // necessary for now because:
  // 1. There is no abort API for ProcessGroup::recvAnysource yet. We have to
  //    feed it a message or kill the thread.
  // 2. A GLOO process cannot send a message to itself. (there is an ongoing
  //    effort to fix this problem).
  sync();
  int dst = (pg_->getRank() + 1) % pg_->getSize();
  enqueue(SendWork(dst, Message({}, {}, MessageType::SHUTDOWN)));
  std::unique_lock<std::mutex> lock(sendQueueMutex_);
  workConsumeCV_.wait(lock, [&] { return sendQueue_.empty(); });
  stop_ = true;
  lock.unlock();

  workProduceCV_.notify_all();
  sendThread_.join();
  listenerThread_.join();
}

void ProcessGroupAgent::sync() {
  // Block until all processes want to sync. This is necessary before acquiring
  // the lock below, because other processes might not enter sync() until they
  // get some response from this RpcAgent.
  pg_->barrier()->wait();
  // Acquire the lock on the send queue to prevent additional messages from
  // being put onto the send queue.
  std::unique_lock<std::mutex> lock(sendQueueMutex_);
  // Wait until the send queue is depleted.
  workConsumeCV_.wait(lock, [&] { return sendQueue_.empty(); });
  // Use another barrier in case different RpcAgents handle different amounts
  // of workload.
  pg_->barrier()->wait();
}

std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
    const std::string& to, Message&& message) {

  auto dstRankIter = nameMap_.find(to);
  TORCH_CHECK(dstRankIter != nameMap_.end(), "Unknown destination worker ", to);

  const int dstRank = dstRankIter->second;
  TORCH_CHECK(dstRank != pg_->getRank(), "ProcessGroupAgent does not support "
      "making RPC calls to self.")

  auto requestId = nextId();
  auto future = std::make_shared<FutureMessage>();
  if (message.isRequest()) {
    {
      std::lock_guard<std::mutex> lock{futureMutex_};
      futures_[requestId] = future;
    }
    message.setId(requestId);
  } else {
    future->markCompleted();
  }

  enqueue(SendWork(dstRank, std::move(message)));
  return future;
}

void ProcessGroupAgent::enqueue(SendWork work) {
  std::unique_lock<std::mutex> lock(sendQueueMutex_);
  sendQueue_.emplace_back(std::move(work));
  lock.unlock();

  workProduceCV_.notify_one();
}

// making sure tensors are not deleted before send finishes
void ProcessGroupAgent::sendLoop() {
  std::unique_lock<std::mutex> lock(sendQueueMutex_);

  while (!stop_) {
    if (sendQueue_.empty()) {
      workProduceCV_.wait(lock);
      continue;
    }

    auto work = std::move(sendQueue_.front());
    sendQueue_.pop_front();
    lock.unlock();

    workConsumeCV_.notify_one();

    std::stringstream ss;
    serialize(work.message_, ss);
    std::string str = ss.str();

    std::vector<torch::Tensor> preamble = {
      torch::tensor(
        {
          (int64_t)pg_->getRank(),
          (int64_t)str.length(),
        }, {torch::kLong})
    };
    pg_->send(preamble, work.dstRank_, work.dstRank_ /* channelTag */)->wait();
    std::vector<torch::Tensor> payload =
        {torch::from_blob((void *)str.c_str(), str.length(), {torch::kChar})};
    pg_->send(payload, work.dstRank_, work.dstRank_ /* channelTag */)->wait();

    lock.lock();
  }
}

void ProcessGroupAgent::listenLoop() {
  while (true) {
    // rank, tensor size
    std::vector<torch::Tensor> preamble = {torch::empty({2}, {torch::kInt64})};
    pg_->recvAnysource(preamble, pg_->getRank())->wait();
    int64_t* preamble_items = preamble.front().storage().data<int64_t>();

    auto srcRank = preamble_items[0];
    auto size = preamble_items[1];

    std::vector<torch::Tensor> tensors = {torch::empty({size}, {torch::kChar})};
    pg_->recv(tensors, srcRank, pg_->getRank())->wait();

    std::stringstream ss(std::string(
        (char*)tensors[0].storage().data<signed char>(), tensors[0].numel()));

    Message message = deserialize(ss);

    if (message.isRequest()) {
      cb_(names_[srcRank], std::move(message), *this);
    } else if (message.isResponse()) {
      auto id = message.id();
      {
        std::lock_guard<std::mutex> lock{futureMutex_};
        futures_[id]->markCompleted(std::move(message));
        futures_.erase(id);
      }
    } else if (message.isShutdown()) {
      break;
    } else {
      AT_ERROR("unrecognized message type ", message.type());
    }
  }
}

}
}
}
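The wire layout produced by serialize/deserialize above (the tensor table, then the payload re-expressed as a char tensor, then an [id, type] tensor, all packed with torch::save) can be mirrored in a few lines of Python. The following is an illustrative sketch of that scheme with hypothetical function names, not an API exposed by this PR:

import io
import torch

def serialize_like_pga(payload, tensors, msg_id, msg_type):
    # tensor table + payload-as-tensor + [id, type], as in serialize() above
    wire = list(tensors)
    wire.append(torch.tensor(list(payload), dtype=torch.uint8))
    wire.append(torch.tensor([msg_id, msg_type], dtype=torch.int64))
    buf = io.BytesIO()
    torch.save(wire, buf)
    return buf.getvalue()

def deserialize_like_pga(data):
    wire = torch.load(io.BytesIO(data))
    msg_id, msg_type = wire.pop().tolist()  # last tensor: [id, type]
    payload = bytes(wire.pop().tolist())    # next: the payload bytes
    return payload, wire, msg_id, msg_type  # remaining tensors: the table

payload, tensors, msg_id, msg_type = deserialize_like_pga(
    serialize_like_pga(b'abc', [torch.ones(2)], 7, 0))
assert payload == b'abc' and msg_id == 7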
torch/csrc/distributed/rpc/ProcessGroupAgent.h (new file, 75 lines)
@@ -0,0 +1,75 @@
#pragma once

#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/rpc/FutureMessage.h>
#include <torch/csrc/distributed/rpc/RpcAgent.h>
#include <torch/csrc/distributed/rpc/functions.h>

#include <deque>
#include <thread>

namespace torch {
namespace distributed {
namespace rpc {

struct SendWork {
  SendWork(const int dstRank,
           Message&& message)
      : dstRank_(dstRank),
        message_(message) {}

  const int dstRank_;
  Message message_;

};

class ProcessGroupAgent : public RpcAgent {
 public:

  ProcessGroupAgent(std::string workerName,
                    std::unordered_map<std::string, int> nameMap,
                    std::shared_ptr<c10d::ProcessGroup> pg);

  // This method wraps the destination information and the message into a
  // SendWork object, and puts the SendWork into a queue. Another thread will
  // consume SendWork from the queue and send it out.
  std::shared_ptr<FutureMessage> send(
      const std::string& to, Message&& message) override;

  void join() override;

  void sync() override;

 private:
  // put SendWork into a queue and notify the sendLoop thread
  void enqueue(SendWork work);
  // sending out the message
  void sendLoop();
  // receiving messages
  void listenLoop();

  int64_t nextId() {
    return nextId_++;
  }

  // worker name -> rank
  std::unordered_map<std::string, int> nameMap_;
  bool stop_;
  std::shared_ptr<c10d::ProcessGroup> pg_;
  std::atomic<int64_t> nextId_;
  // names_[rank] stores the name of the corresponding worker, use this vector
  // to get worker name from rank and pass it to the RequestCallback.
  std::vector<std::string> names_;
  std::deque<SendWork> sendQueue_;
  std::mutex sendQueueMutex_;
  std::condition_variable workProduceCV_;
  std::condition_variable workConsumeCV_;
  std::thread sendThread_;
  std::thread listenerThread_;
  std::unordered_map<int64_t, std::shared_ptr<FutureMessage>> futures_;
  std::mutex futureMutex_;
};

}
}
}
torch/csrc/distributed/rpc/RpcAgent.cpp (new file, 14 lines)
@@ -0,0 +1,14 @@
#include <torch/csrc/distributed/rpc/RpcAgent.h>

namespace torch {
namespace distributed {
namespace rpc {

RpcAgent::RpcAgent(std::string workerName, RequestCallback cb)
    : workerName_(std::move(workerName)), cb_(std::move(cb)) {}

RpcAgent::~RpcAgent() = default;

}
}
}
torch/csrc/distributed/rpc/RpcAgent.h (new file, 67 lines)
@@ -0,0 +1,67 @@
#pragma once

#include <torch/csrc/distributed/rpc/FutureMessage.h>
#include <torch/csrc/distributed/rpc/Message.h>

namespace torch {
namespace distributed {
namespace rpc {


// RpcAgent is the base class for sending and receiving RPC messages. It
// provides a unified ``send`` API for both request and response messages, and
// will invoke the given ``RequestCallback`` to process received requests. It
// should immediately become ready to serve requests and accept responses after
// construction.
class RpcAgent;

// RpcAgent implementations should invoke ``RequestCallback`` to process
// received requests. There is no restriction on the implementation's threading
// model. This function takes the name of the request sender, an rvalue
// reference to the Message object, and a reference to the RpcAgent itself.
// Having a reference to the RpcAgent allows the ``RequestCallback``
// implementation to be both stateless and non-blocking. It may enqueue the
// message and the RpcAgent reference, and use a different set of threads to
// process them later.
using RequestCallback = std::function<void(std::string, Message&&, RpcAgent&)>;

class RpcAgent {
 public:
  // The ``workerName`` is the globally unique name for this RpcAgent. It is up
  // to the RpcAgent implementation to determine how to resolve names.
  // The ``RequestCallback`` will be invoked to handle received requests. This
  // RpcAgent base class makes no assumption about the thread-safety of the
  // ``RequestCallback``. RpcAgent implementations need to make sure that their
  // threading model conforms to the ``RequestCallback``'s requirements.
  RpcAgent(std::string workerName, RequestCallback cb);

  virtual ~RpcAgent();

  // Send a message to the ``RpcAgent`` of name ``to`` and return a
  // ``FutureMessage`` ptr. The implementation must be asynchronous, i.e., it
  // cannot block until it receives the response.
  //
  // If ``message.isRequest()`` is true, the ``FutureMessage`` will be completed
  // when the response arrives. For other message types, the Future should be
  // ignored by the caller.
  //
  // TODO: avoid passing strings all the time, e.g., by using symbols as a
  // faster alternative.
  virtual std::shared_ptr<FutureMessage> send(
      const std::string& to, Message&& message) = 0;

  // Call sync and join all internal threads. This method should be called
  // before every RPC process exits.
  virtual void join() = 0;

  // Synchronize this process with other RpcAgent processes. Block until all
  // RpcAgents reach this method and send all pending messages.
  virtual void sync() = 0;

 protected:
  const std::string workerName_;
  const RequestCallback cb_;
};

}
}
}
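The comment above explains why ``RequestCallback`` receives the ``RpcAgent&``: a non-blocking implementation can park the work and respond later. ``processRequestBlocking`` in functions.cpp below is the blocking variant; a minimal Python sketch of the non-blocking pattern described in the summary follows (the ``handle_request`` helper and ``EchoAgent`` stub are hypothetical, for illustration only):

import queue
import threading

pending = queue.Queue()

def non_blocking_callback(sender, request, agent):
    # return immediately; a worker thread sends the response later
    pending.put((sender, request, agent))

def handle_request(request):
    # stand-in for real request processing (e.g., running the builtin op)
    return 'response to %r' % (request,)

def _worker():
    while True:
        sender, request, agent = pending.get()
        # possible because the agent reference was enqueued with the request
        agent.send(sender, handle_request(request))

class EchoAgent:
    # stand-in for RpcAgent; just records outgoing responses
    def send(self, to, message):
        print('send to %s: %s' % (to, message))

threading.Thread(target=_worker, daemon=True).start()
non_blocking_callback('worker0', 'add(1, 2)', EchoAgent())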
torch/csrc/distributed/rpc/ScriptCall.cpp (new file, 104 lines)
@@ -0,0 +1,104 @@
#include <torch/csrc/distributed/rpc/ScriptCall.h>

namespace torch {
namespace distributed {
namespace rpc {

namespace {

using torch::jit::Pickler;
using torch::jit::Unpickler;

} // namespace

const std::string ScriptCall::BUILTIN_OP_NAMESPACE_("torch.ops.aten.");
const std::string ScriptCall::ATEN_PREFIX_("aten::");

ScriptCall::ScriptCall(
    std::shared_ptr<Operator> op, std::vector<at::IValue>&& args)
    : op_(std::move(op)), stack_(args) {}

std::shared_ptr<Operator> ScriptCall::op() const {
  return *op_;
}

const std::vector<at::IValue>& ScriptCall::stack() const {
  return stack_;
}

Message ScriptCall::toMessage() {
  std::vector<torch::Tensor> tensor_table;
  Pickler pickler(&tensor_table);

  pickler.protocol();
  pickler.startTuple();
  for (auto& value: stack_) {
    pickler.pushIValue(value);
  }
  if (op_) {
    // builtin ops

    // TODO: replace this with a real overload_name when FunctionSchema supports
    // that.
    pickler.pushIValue(toString((*op_)->schema()));
    // insert qualified name
    auto opName = (*op_)->schema().name();
    TORCH_CHECK(opName.find("::") == opName.rfind("::") &&
        opName.rfind(ATEN_PREFIX_) == 0,
        "Unexpected operator name ", opName);
    // aten::add -> torch.ops.aten.add
    opName.replace(0, ATEN_PREFIX_.length(), BUILTIN_OP_NAMESPACE_);
    pickler.pushIValue(opName);
  }
  pickler.endTuple();
  pickler.stop();

  auto payload = pickler.stack();
  return Message(std::move(payload),
                 std::move(tensor_table),
                 MessageType::SCRIPT_CALL);
}

ScriptCall ScriptCall::fromMessage(const Message& message) {
  auto payload = static_cast<const void*>(message.payload().data());
  auto payload_size = message.payload().size();
  Unpickler unpickler(payload, payload_size, &message.tensors(), nullptr);

  std::vector<IValue> values = unpickler.parse_ivalue_list();

  TORCH_CHECK(values.size() >= 1, "Message of a ScriptCall must at least "
      "contain one IValue as the operator schema.");

  const std::string& qualifiedName = values.back().toStringRef();
  if (qualifiedName.rfind(BUILTIN_OP_NAMESPACE_) == 0) {
    values.pop_back();

    const std::string& str_schema = values.back().toStringRef();
    // extract symbol from the schema
    auto str_symbol = str_schema.substr(0, str_schema.find('('));
    auto symbol = at::Symbol::fromQualString(str_symbol);
    auto op = matchOperator(symbol, str_schema);
    // remove str_schema from values
    values.pop_back();

    return ScriptCall(op, std::move(values));
  } else {
    AT_ERROR("Unrecognized qualified name ", qualifiedName);
  }
}

std::shared_ptr<Operator> ScriptCall::matchOperator(
    at::Symbol& symbol, const std::string& str_schema) {
  // TODO: This is a temporary solution. We should pass enough information to
  // allow deterministic matching to one operator.
  for (auto op: torch::jit::getAllOperatorsFor(symbol)) {
    if (toString(op->schema()).compare(str_schema) == 0) {
      return op;
    }
  }
  AT_ERROR("Cannot find matching operator for schema ", str_schema);
}

} // namespace rpc
} // namespace distributed
} // namespace torch
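On the wire, a ScriptCall is a pickled tuple of (args..., schema string, mangled operator name), with aten::add renamed to torch.ops.aten.add. A Python sketch of the same layout and name mangling, using Python's pickle as a stand-in for torch::jit::Pickler (illustrative only, with hypothetical function names):

import pickle

ATEN_PREFIX = 'aten::'
BUILTIN_OP_NAMESPACE = 'torch.ops.aten.'

def script_call_payload(op_name, schema_str, args):
    # aten::add -> torch.ops.aten.add, as in ScriptCall::toMessage
    assert op_name.startswith(ATEN_PREFIX), op_name
    qualified = BUILTIN_OP_NAMESPACE + op_name[len(ATEN_PREFIX):]
    return pickle.dumps(tuple(args) + (schema_str, qualified))

def parse_script_call(payload):
    *args, schema_str, qualified = pickle.loads(payload)
    # fromMessage dispatches on the qualified-name prefix
    assert qualified.startswith(BUILTIN_OP_NAMESPACE), qualified
    return qualified, schema_str, list(args)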
torch/csrc/distributed/rpc/ScriptCall.h (new file, 46 lines)
@@ -0,0 +1,46 @@
#pragma once

#include <c10/util/Optional.h>
#include <torch/csrc/distributed/rpc/Message.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pickler.h>
#include <vector>

namespace torch {
namespace distributed {
namespace rpc {

using torch::jit::Operator;

// A ScriptCall instance represents an invocation of a builtin operator or a
// TorchScript function (not implemented yet). If it is a builtin operator, it
// contains a shared ptr to the `Operator` and a list of arguments.
class TORCH_API ScriptCall final {
 public:
  ScriptCall(std::shared_ptr<Operator> op, std::vector<at::IValue>&& args);

  std::shared_ptr<Operator> op() const;
  // return the argument stack of this builtin operator
  const std::vector<at::IValue>& stack() const;

  Message toMessage();
  static ScriptCall fromMessage(const Message& message);

 private:

  // Given an operator symbol and a string schema, return the matched operator.
  static std::shared_ptr<Operator> matchOperator(
      at::Symbol& symbol, const std::string& str_schema);

  static const std::string BUILTIN_OP_NAMESPACE_;
  static const std::string ATEN_PREFIX_;

  // This field has a value if this ScriptCall represents an invocation of a
  // builtin operator.
  c10::optional<std::shared_ptr<Operator>> op_;
  const std::vector<at::IValue> stack_;
};

}
}
}
torch/csrc/distributed/rpc/ScriptRet.cpp (new file, 50 lines)
@@ -0,0 +1,50 @@
#include <torch/csrc/distributed/rpc/ScriptRet.h>

namespace torch {
namespace distributed {
namespace rpc {

namespace {

using torch::jit::Pickler;
using torch::jit::Unpickler;

} // namespace

ScriptRet::ScriptRet(at::IValue&& value) : value_(value) {}

const at::IValue& ScriptRet::value() {
  return value_;
}

Message ScriptRet::toMessage() {
  std::vector<torch::Tensor> tensor_table;
  Pickler pickler(&tensor_table);

  pickler.protocol();
  pickler.startTuple();
  pickler.pushIValue(value_);
  pickler.endTuple();
  pickler.stop();

  auto payload = pickler.stack();
  return Message(std::move(payload),
                 std::move(tensor_table),
                 MessageType::SCRIPT_RET);
}

ScriptRet ScriptRet::fromMessage(const Message& message) {
  auto payload = static_cast<const void*>(message.payload().data());
  auto payload_size = message.payload().size();
  Unpickler unpickler(payload, payload_size, &message.tensors(), nullptr);

  auto values = unpickler.parse_ivalue_list();
  AT_ASSERT(values.size() == 1, "Return value of a builtin operator or a "
      "TorchScript function should be a single IValue, got a vector of size ",
      values.size());
  return ScriptRet(std::move(values.front()));
}

} // namespace rpc
} // namespace distributed
} // namespace torch
torch/csrc/distributed/rpc/ScriptRet.h (new file, 25 lines)
@@ -0,0 +1,25 @@
#pragma once

#include <torch/csrc/distributed/rpc/Message.h>
#include <torch/csrc/jit/pickler.h>

namespace torch {
namespace distributed {
namespace rpc {

// Return value of a builtin operator or a TorchScript function.
class TORCH_API ScriptRet final {
 public:
  explicit ScriptRet(at::IValue&& value);

  const at::IValue& value();
  Message toMessage();
  static ScriptRet fromMessage(const Message& message);

 private:
  const at::IValue value_;
};

}
}
}
torch/csrc/distributed/rpc/functions.cpp (new file, 32 lines)
@@ -0,0 +1,32 @@
#include <torch/csrc/distributed/rpc/functions.h>

namespace torch {
namespace distributed {
namespace rpc {

void processRequestBlocking(
    const std::string& from, Message&& request, RpcAgent& agent) {
  switch (request.type()) {
    case MessageType::SCRIPT_CALL: {
      ScriptCall op = ScriptCall::fromMessage(request);

      auto stack = op.stack();
      op.op()->getOperation()(stack);
      AT_ASSERT(stack.size() == 1, "Return value of a builtin operator or a "
          "TorchScript function should be a single IValue, got a vector of "
          "size ", stack.size());

      auto response = ScriptRet(std::move(stack.front())).toMessage();
      response.setId(request.id());
      agent.send(from, std::move(response));
      break;
    }
    default: {
      AT_ERROR("Request type ", request.type(), " not supported.");
    }
  }
}

}
}
}
torch/csrc/distributed/rpc/functions.h (new file, 18 lines)
@@ -0,0 +1,18 @@
#pragma once

#include <torch/csrc/distributed/rpc/FutureMessage.h>
#include <torch/csrc/distributed/rpc/Message.h>
#include <torch/csrc/distributed/rpc/RpcAgent.h>
#include <torch/csrc/distributed/rpc/ScriptCall.h>
#include <torch/csrc/distributed/rpc/ScriptRet.h>

namespace torch {
namespace distributed {
namespace rpc {

void processRequestBlocking(
    const std::string& from, Message&& message, RpcAgent& agent);

} // rpc
} // distributed
} // torch
torch/csrc/distributed/rpc/init.cpp (new file, 83 lines)
@@ -0,0 +1,83 @@
#include <torch/csrc/python_headers.h>

#include <torch/csrc/distributed/rpc/FutureMessage.h>
#include <torch/csrc/distributed/rpc/ProcessGroupAgent.h>
#include <torch/csrc/distributed/rpc/RpcAgent.h>
#include <torch/csrc/distributed/rpc/functions.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>


namespace torch {
namespace distributed {
namespace rpc {

namespace {

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

PyObject* rpc_init(PyObject* /* unused */) {
  auto dist_module = THPObjectPtr(PyImport_ImportModule("torch.distributed"));
  if (!dist_module) {
    throw python_error();
  }

  auto module = py::handle(dist_module).cast<py::module>();

  auto rpcAgent = shared_ptr_class_<RpcAgent>(module, "RpcAgent")
      .def("join",
           &RpcAgent::join,
           py::call_guard<py::gil_scoped_release>())
      .def("sync",
           &RpcAgent::sync,
           py::call_guard<py::gil_scoped_release>());

  auto futureMessage = shared_ptr_class_<FutureMessage>(module, "FutureMessage")
      .def("wait",
           [&](FutureMessage& fut) {
             return to_py_obj(fut.wait());
           },
           py::call_guard<py::gil_scoped_release>());

  auto processGroupAgent =
      shared_ptr_class_<ProcessGroupAgent>(
          module, "ProcessGroupAgent", rpcAgent)
          .def(py::init<std::string,
                        std::unordered_map<std::string, int>,
                        std::shared_ptr<::c10d::ProcessGroup>>())
          .def("join",
               &ProcessGroupAgent::join,
               py::call_guard<py::gil_scoped_release>())
          .def("sync",
               &ProcessGroupAgent::sync,
               py::call_guard<py::gil_scoped_release>());

  module.def("invoke_rpc", [](
      RpcAgent& agent,
      const std::string& dstName,
      const std::string& opName,
      const py::args& args,
      const py::kwargs& kwargs) {
    return py_rpc(agent, dstName, opName, args, kwargs);
  });

  Py_RETURN_TRUE;
}

} // namespace

static PyMethodDef methods[] = { // NOLINT
    {"_rpc_init", (PyCFunction)rpc_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace rpc
} // namespace distributed
} // namespace torch
torch/csrc/distributed/rpc/python_functions.cpp (new file, 49 lines)
@@ -0,0 +1,49 @@
#include <torch/csrc/distributed/rpc/python_functions.h>

namespace torch {
namespace distributed {
namespace rpc {

py::object to_py_obj(const Message& message) {
  switch (message.type()) {
    case MessageType::SCRIPT_RET: {
      ScriptRet ret = ScriptRet::fromMessage(message);
      Stack stack;
      stack.push_back(ret.value());
      return torch::jit::createPyObjectForStack(std::move(stack));
    }
    default: {
      AT_ERROR("Unrecognized response message type ", message.type());
    }
  }
}

std::shared_ptr<FutureMessage> py_rpc(
    RpcAgent& agent,
    const std::string& dstName,
    const std::string& opName,
    const py::args& args,
    const py::kwargs& kwargs) {
  if (opName.rfind("aten", 0) == 0) {
    // builtin operators.
    Symbol symbol = Symbol::fromQualString(opName);
    for (const auto& op: torch::jit::getAllOperatorsFor(symbol)) {
      try {
        // FIXME: This is a temporary solution. We should at least refactor
        // ``createStackForSchema`` to avoid throwing an error.
        Stack stack = torch::jit::createStackForSchema(
            op->schema(), args, kwargs, c10::nullopt);

        return agent.send(
            dstName, ScriptCall(op, std::move(stack)).toMessage());
      } catch (std::runtime_error) {}
    }
  }

  AT_ERROR("Failed to match operator name ", opName, " and arguments "
      "(args: ", args, ", kwargs: ", kwargs, ") to a builtin operator");
}

}
}
}
torch/csrc/distributed/rpc/python_functions.h (new file, 28 lines)
@@ -0,0 +1,28 @@
#pragma once


#include <torch/csrc/distributed/rpc/FutureMessage.h>
#include <torch/csrc/distributed/rpc/Message.h>
#include <torch/csrc/distributed/rpc/RpcAgent.h>
#include <torch/csrc/distributed/rpc/ScriptCall.h>
#include <torch/csrc/distributed/rpc/ScriptRet.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>


namespace torch {
namespace distributed {
namespace rpc {

py::object to_py_obj(const Message& message);

std::shared_ptr<FutureMessage> py_rpc(
    RpcAgent& agent,
    const std::string& dstName,
    const std::string& opName,
    const py::args& args,
    const py::kwargs& kwargs);

}
}
}
torch/csrc/distributed/rpc/rpc.h (new file, 13 lines)
@@ -0,0 +1,13 @@
#pragma once

#include <torch/csrc/python_headers.h>

namespace torch {
namespace distributed {
namespace rpc {

PyMethodDef* python_functions();

} // namespace rpc
} // namespace distributed
} // namespace torch
@@ -2,10 +2,10 @@ import torch


 def is_available():
-    return hasattr(torch._C, "_c10d_init")
+    return hasattr(torch._C, "_c10d_init") and hasattr(torch._C, "_rpc_init")


-if is_available() and not torch._C._c10d_init():
+if is_available() and not (torch._C._c10d_init() and torch._C._rpc_init()):
     raise RuntimeError("Failed to initialize PyTorch distributed support")
@@ -15,3 +15,4 @@ if is_available():
     # See the comment in `distributed_c10d.py` above `_backend` on why we expose
     # this.
     from .distributed_c10d import _backend  # noqa: F401
+    from .rpc import *  # noqa: F401
torch/distributed/rpc.py (new file, 174 lines)
@@ -0,0 +1,174 @@
from . import invoke_rpc
from . import ProcessGroupAgent

import array
import sys
import torch


_agent = None


def _collect_worker_names(name, group):
    from . import all_gather
    from . import get_world_size

    # collect name length
    ws = get_world_size(group)
    name_bytes = name if sys.version_info < (3, 0) else bytes(name, 'utf8')
    name_bytes = list(array.array('B', name_bytes))
    name_len = len(name_bytes)
    len_input = torch.ones(1, dtype=torch.int64) * name_len
    len_outputs = [torch.empty(1, dtype=torch.int64) for _ in range(ws)]
    all_gather(len_outputs, len_input, group=group)

    # collect name value
    max_len = torch.stack(len_outputs).max().item()
    name_input = torch.empty(max_len, dtype=torch.uint8)
    name_input[:name_len] = torch.tensor(name_bytes, dtype=torch.uint8)
    name_outputs = [torch.empty(max_len, dtype=torch.uint8) for _ in range(ws)]
    all_gather(name_outputs, name_input, group=group)

    names = []
    for i in range(ws):
        name_tensor = name_outputs[i][:len_outputs[i]]
        names.append(bytearray(name_tensor.tolist()).decode('utf8'))

    return names


def join_rpc():
    r"""
    Block until all local and remote RPC processes reach this method, process
    (send and receive) all pending messages, and then destroy the local RPC
    agent. Every RPC process must call this method before exit.
    """
    global _agent

    if _agent:
        _agent.join()
        _agent = None


def sync_rpc():
    r"""
    Block until all local and remote RPC processes reach this method and finish
    sending all pending RPCs. As this method synchronizes at the process
    level, if multiple threads are spawned, only one of them should call this
    method at a time.
    """
    if _agent is None:
        raise RuntimeError("RPC has not been initialized. "
                           "Call init_rpc(name) first.")

    _agent.sync()


# TODO: add a context manager to wrap init_rpc and join_rpc
def init_rpc(name, backend='pg'):
    r"""
    Initialize the local RPC agent, which immediately makes the current process
    ready to send and receive RPCs. The caller needs to make sure the specified
    backend is properly initialized before calling this method. For example, to
    use the ``pg`` (ProcessGroup) backend, ``init_process_group`` must be
    invoked prior to this method.

    Arguments:
        name (str): a globally unique name of the local RPC agent. (e.g.,
                    ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
        backend (str): type of RPC backend implementation. Currently,
                       process group backend ``"pg"`` is the only available
                       backend implementation. (default: ``"pg"``).
    """
    global _agent

    if _agent:
        raise RuntimeError("RPC is already initialized")

    if backend == 'pg':
        from .distributed_c10d import _get_default_group
        group = _get_default_group()
        # TODO: issue #23232
        names = _collect_worker_names(name, group)
        name_dict = {names[r]: r for r in range(len(names))}
        _agent = ProcessGroupAgent(name, name_dict, group)
    else:
        raise RuntimeError("Unrecognized RPC backend ", backend)


def rpc(to, func, args=None, kwargs=None, async_call=False):
    r"""
    Make an RPC call to run function ``func`` on worker ``to``. By default, it
    blocks until the return value is locally available. RPC messages are sent
    and received in parallel to execution of Python code. This method is
    thread-safe.

    Arguments:
        to (str): name of the destination worker.
        func (callable): a builtin function (e.g., ``torch.add``).
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
                       invocation.
        async_call (bool): If set to ``True``, this will be an asynchronous RPC,
                           and returns a ``torch.distributed.FutureMessage``
                           object immediately. Otherwise, this RPC will block
                           until the return value is locally available.
                           (default: ``False``)

    Returns:
        If ``async_call`` is ``False``, returns the result of running ``func``
        on ``args`` and ``kwargs``. If ``async_call`` is ``True``, returns a
        ``torch.distributed.FutureMessage`` object that can be waited on. When
        completed, the return value of ``func`` on ``args`` and ``kwargs`` can
        be retrieved from the ``FutureMessage`` object.

    Example::

        Synchronous example:

        On worker 0:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_rpc("worker0")
        >>> ret = dist.rpc("worker1", torch.add, args=(torch.ones(2), 3))
        >>> dist.join_rpc()

        On worker 1:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_rpc("worker1")
        >>> dist.join_rpc()

        Asynchronous example:

        On worker 0:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_rpc("worker0")
        >>> fut1 = dist.rpc("worker1", torch.add, args=(torch.ones(2), 3), async_call=True)
        >>> fut2 = dist.rpc("worker1", torch.add, args=(torch.ones(2), 2), async_call=True)
        >>> result = fut1.wait() + fut2.wait()
        >>> dist.join_rpc()

        On worker 1:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_rpc("worker1")
        >>> dist.join_rpc()
    """
    if _agent is None:
        raise RuntimeError("RPC has not been initialized. "
                           "Call init_rpc(name) first.")

    qualified_name = torch.jit._find_builtin(func)
    if qualified_name is None:
        raise RuntimeError("unknown builtin function %s." % func)

    args = args if args else ()
    kwargs = kwargs if kwargs else {}
    fut = invoke_rpc(_agent, to, qualified_name, *args, **kwargs)

    if async_call:
        return fut
    else:
        return fut.wait()