Let RpcAgent::send() return JitFuture (#49906)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49906

This commit modifies RPC Message to inherit from `torch::CustomClassHolder`,
and wraps a Message in an IValue in `RpcAgent::send()`.

Test Plan: Imported from OSS

Reviewed By: lw

Differential Revision: D25719518

Pulled By: mrshenli

fbshipit-source-id: 694e40021e49e396da1620a2f81226522341550b
This commit is contained in:
Shen Li
2021-01-07 19:43:44 -08:00
committed by Facebook GitHub Bot
parent 4de6b279c8
commit 84e3237a53
13 changed files with 78 additions and 22 deletions

View File

@ -47,8 +47,8 @@ variable_list RecvRpcBackward::apply(variable_list&& grads) {
// Send the gradients over to the appropriate node.
auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
auto futureMessage = rpcAgent->send(
rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage());
auto futureMessage = rpc::RpcAgent::toFutureMessage(rpcAgent->send(
rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage()));
// Record the future in the context.
sharedContext->addOutstandingRpc(futureMessage);

View File

@ -160,9 +160,11 @@ std::shared_ptr<FutureMessage> sendMessageWithAutograd(
std::move(msg),
rpc::MessageType::RUN_WITH_PROFILING_REQ,
std::move(profilerConfig));
fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
fut = rpc::RpcAgent::toFutureMessage(
agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds));
} else {
fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds);
fut = rpc::RpcAgent::toFutureMessage(
agent.send(dst, std::move(msg), rpcTimeoutSeconds));
}
return fut;

View File

@ -104,6 +104,18 @@ Message createExceptionResponse(const std::string& exceptionStr, int64_t id) {
id);
}
namespace {
// NB: need to call torch::class_ to register Message in the map returned by
// c10::getCustomClassTypeMap(). Otherwise, Message cannot be wrapped within
// an IValue.
// NB: add this line here instead of in rpc/init.cpp because 1) we have C++
// only tests that won't run rpc/init.cpp; 2) Message is not meant to be
// visible from Python.
static const auto message = torch::class_<Message>("rpc", "_Message");
} // namespace
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -93,7 +93,7 @@ enum MessageType {
// Layers above ``RpcAgent`` only converts ScriptCall, ScriptResp, PythonCall,
// and PythonResp into a Message, and it is up to the RpcAgent
// implementation to determine how to serialize a message.
class TORCH_API Message final {
class TORCH_API Message final : public torch::CustomClassHolder {
public:
Message();

View File

@ -287,7 +287,7 @@ void ProcessGroupAgent::shutdownImpl() {
threadPool_.waitWorkComplete();
}
std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
std::shared_ptr<JitFuture> ProcessGroupAgent::send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds) {
@ -369,7 +369,7 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
// to our receiving queue.
if (to.id_ == (worker_id_t)pg_->getRank()) {
sendToSelf(std::move(message));
return future;
return toJitFuture(std::move(future));
}
// NB: cannot directly pass ``to`` to the ``SendWork``, because it might no
@ -382,7 +382,9 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
// the C++ land. Hence, we have to explicitly use the ``WorkerInfo`` in the
// C++ land.
enqueueSend(SendWork(allWorkerInfo_[to.id_], std::move(message)));
return future;
auto jitFuture = toJitFuture(std::move(future));
return jitFuture;
}
void ProcessGroupAgent::handleSend(const SendWork& work) {

View File

@ -88,7 +88,7 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
// This method wraps the destination information and the message into a
// SendWork object, and put the SendWork into a queue. Another thread will
// consume SendWork from the queue and send it out.
std::shared_ptr<FutureMessage> send(
std::shared_ptr<JitFuture> send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout) override;

View File

@ -331,10 +331,10 @@ void PyRRef::backward(
// Invoke distributed backward remotely.
auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
rpcAgent
->send(
rpc::RpcAgent::toFutureMessage(
rpcAgent->send(
rpcAgent->getWorkerInfo(rref->owner()),
std::move(rrefBackwardReq).toMessage())
std::move(rrefBackwardReq).toMessage()))
->wait();
}
}

View File

@ -62,7 +62,7 @@ std::shared_ptr<FutureMessage> RpcAgent::sendWithRetries(
computeNewRpcRetryTime(retryOptions, /* retryCount */ 0);
// Making a copy of the message so it can be retried after the first send.
Message msgCopy = message;
auto fm = send(to, std::move(message));
auto fm = toFutureMessage(send(to, std::move(message)));
auto firstRetryRpc = std::make_shared<RpcRetryInfo>(
to,
std::move(msgCopy),
@ -133,7 +133,7 @@ void RpcAgent::retryExpiredRpcs() {
// with an error, since this RPC never succeeded and can no longer be
// retried.
try {
fm = send(earliestRpc->to_, std::move(msgCopy));
fm = toFutureMessage(send(earliestRpc->to_, std::move(msgCopy)));
futures.emplace_back(fm, earliestRpc);
} catch (std::exception& e) {
// We must store the futures and exception messages here and only mark

View File

@ -157,7 +157,7 @@ class TORCH_API RpcAgent {
// If ``message.isRequest()`` is true, the ``FutureMessage`` will be
// completed when the response arrives. For other message types, the Future
// should be ignored by the caller.
virtual std::shared_ptr<FutureMessage> send(
virtual std::shared_ptr<JitFuture> send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout) = 0;
@ -259,6 +259,46 @@ class TORCH_API RpcAgent {
// Get the type resolver
std::shared_ptr<TypeResolver> getTypeResolver();
static std::shared_ptr<JitFuture> toJitFuture(
std::shared_ptr<FutureMessage>&& fm) {
auto jitFuture = std::make_shared<JitFuture>(at::AnyClassType::get());
std::weak_ptr<FutureMessage> wp = fm;
fm->addCallback(
[jitFuture, wp]() mutable {
auto future = wp.lock();
TORCH_INTERNAL_ASSERT(future);
if (future->hasError()) {
jitFuture->setError(std::make_exception_ptr(*(future->error())));
} else {
jitFuture->markCompleted(IValue(
c10::make_intrusive<Message>(std::move(*future).moveValue())));
}
}
);
return jitFuture;
}
static std::shared_ptr<FutureMessage> toFutureMessage(
std::shared_ptr<JitFuture>&& jitFuture) {
auto fm = std::make_shared<FutureMessage>();
std::weak_ptr<JitFuture> wp = jitFuture;
jitFuture->addCallback(
[fm, wp]() mutable {
auto future = wp.lock();
TORCH_INTERNAL_ASSERT(future);
if (future->hasError()) {
fm->setError(future->tryRetrieveErrorMessage());
} else {
fm->markCompleted(
std::move(*future->value().toCustomClass<Message>()));
}
}
);
return fm;
}
protected:
const WorkerInfo workerInfo_;
const std::unique_ptr<RequestCallback> cb_;

View File

@ -604,7 +604,7 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
});
}
std::shared_ptr<FutureMessage> TensorPipeAgent::send(
std::shared_ptr<JitFuture> TensorPipeAgent::send(
const WorkerInfo& toWorkerInfo,
Message&& requestMessage,
const float rpcTimeoutSeconds) {
@ -778,8 +778,8 @@ std::shared_ptr<FutureMessage> TensorPipeAgent::send(
});
});
return std::shared_ptr<FutureMessage>(
futureResponseMessage, &futureResponseMessage->futMsg);
return toJitFuture(std::shared_ptr<FutureMessage>(
futureResponseMessage, &futureResponseMessage->futMsg));
}
void TensorPipeAgent::pollTimeoutRpcs() {

View File

@ -181,7 +181,7 @@ class TensorPipeAgent : public RpcAgent {
TensorPipeAgent(const TensorPipeAgent&) = delete;
TensorPipeAgent& operator=(const TensorPipeAgent&) = delete;
std::shared_ptr<FutureMessage> send(
std::shared_ptr<JitFuture> send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout) override;

View File

@ -56,7 +56,7 @@ std::unordered_map<MessageType, float, std::hash<int>> FaultyProcessGroupAgent::
return delayMessages;
}
std::shared_ptr<FutureMessage> FaultyProcessGroupAgent::send(
std::shared_ptr<JitFuture> FaultyProcessGroupAgent::send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds) {
@ -82,7 +82,7 @@ std::shared_ptr<FutureMessage> FaultyProcessGroupAgent::send(
fm->setError(makeRPCError(
c10::str("Send attempt failed intentionally for ", key),
RPCErrorType::INTENTIONAL_FAILURE));
return fm;
return toJitFuture(std::move(fm));
} else {
lock.unlock();
return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds);

View File

@ -43,7 +43,7 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent {
int failNumSends = 0);
// Faulty send function for this class.
std::shared_ptr<FutureMessage> send(
std::shared_ptr<JitFuture> send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds =