mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Let RpcAgent::send() return JitFuture (#49906)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49906 This commit modifies RPC Message to inherit from `torch::CustomClassHolder`, and wraps a Message in an IValue in `RpcAgent::send()`. Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25719518 Pulled By: mrshenli fbshipit-source-id: 694e40021e49e396da1620a2f81226522341550b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4de6b279c8
commit
84e3237a53
@ -47,8 +47,8 @@ variable_list RecvRpcBackward::apply(variable_list&& grads) {
|
||||
|
||||
// Send the gradients over to the appropriate node.
|
||||
auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
|
||||
auto futureMessage = rpcAgent->send(
|
||||
rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage());
|
||||
auto futureMessage = rpc::RpcAgent::toFutureMessage(rpcAgent->send(
|
||||
rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage()));
|
||||
|
||||
// Record the future in the context.
|
||||
sharedContext->addOutstandingRpc(futureMessage);
|
||||
|
@ -160,9 +160,11 @@ std::shared_ptr<FutureMessage> sendMessageWithAutograd(
|
||||
std::move(msg),
|
||||
rpc::MessageType::RUN_WITH_PROFILING_REQ,
|
||||
std::move(profilerConfig));
|
||||
fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
|
||||
fut = rpc::RpcAgent::toFutureMessage(
|
||||
agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds));
|
||||
} else {
|
||||
fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds);
|
||||
fut = rpc::RpcAgent::toFutureMessage(
|
||||
agent.send(dst, std::move(msg), rpcTimeoutSeconds));
|
||||
}
|
||||
|
||||
return fut;
|
||||
|
@ -104,6 +104,18 @@ Message createExceptionResponse(const std::string& exceptionStr, int64_t id) {
|
||||
id);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// NB: need to call torch::class_ to register Message in the map returned by
|
||||
// c10::getCustomClassTypeMap(). Otherwise, Message cannot be wrapped within
|
||||
// an IValue.
|
||||
// NB: add this line here instead of in rpc/init.cpp because 1) we have C++
|
||||
// only tests that won't run rpc/init.cpp; 2) Message is not meant to be
|
||||
// visible from Python.
|
||||
static const auto message = torch::class_<Message>("rpc", "_Message");
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
||||
|
@ -93,7 +93,7 @@ enum MessageType {
|
||||
// Layers above ``RpcAgent`` only convert ScriptCall, ScriptResp, PythonCall,
|
||||
// and PythonResp into a Message, and it is up to the RpcAgent
|
||||
// implementation to determine how to serialize a message.
|
||||
class TORCH_API Message final {
|
||||
class TORCH_API Message final : public torch::CustomClassHolder {
|
||||
public:
|
||||
Message();
|
||||
|
||||
|
@ -287,7 +287,7 @@ void ProcessGroupAgent::shutdownImpl() {
|
||||
threadPool_.waitWorkComplete();
|
||||
}
|
||||
|
||||
std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
|
||||
std::shared_ptr<JitFuture> ProcessGroupAgent::send(
|
||||
const WorkerInfo& to,
|
||||
Message&& message,
|
||||
const float rpcTimeoutSeconds) {
|
||||
@ -369,7 +369,7 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
|
||||
// to our receiving queue.
|
||||
if (to.id_ == (worker_id_t)pg_->getRank()) {
|
||||
sendToSelf(std::move(message));
|
||||
return future;
|
||||
return toJitFuture(std::move(future));
|
||||
}
|
||||
|
||||
// NB: cannot directly pass ``to`` to the ``SendWork``, because it might no
|
||||
@ -382,7 +382,9 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
|
||||
// the C++ land. Hence, we have to explicitly use the ``WorkerInfo`` in the
|
||||
// C++ land.
|
||||
enqueueSend(SendWork(allWorkerInfo_[to.id_], std::move(message)));
|
||||
return future;
|
||||
|
||||
auto jitFuture = toJitFuture(std::move(future));
|
||||
return jitFuture;
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::handleSend(const SendWork& work) {
|
||||
|
@ -88,7 +88,7 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
|
||||
// This method wraps the destination information and the message into a
|
||||
// SendWork object, and puts the SendWork into a queue. Another thread will
|
||||
// consume SendWork from the queue and send it out.
|
||||
std::shared_ptr<FutureMessage> send(
|
||||
std::shared_ptr<JitFuture> send(
|
||||
const WorkerInfo& to,
|
||||
Message&& message,
|
||||
const float rpcTimeoutSeconds = kUnsetRpcTimeout) override;
|
||||
|
@ -331,10 +331,10 @@ void PyRRef::backward(
|
||||
|
||||
// Invoke distributed backward remotely.
|
||||
auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
|
||||
rpcAgent
|
||||
->send(
|
||||
rpc::RpcAgent::toFutureMessage(
|
||||
rpcAgent->send(
|
||||
rpcAgent->getWorkerInfo(rref->owner()),
|
||||
std::move(rrefBackwardReq).toMessage())
|
||||
std::move(rrefBackwardReq).toMessage()))
|
||||
->wait();
|
||||
}
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ std::shared_ptr<FutureMessage> RpcAgent::sendWithRetries(
|
||||
computeNewRpcRetryTime(retryOptions, /* retryCount */ 0);
|
||||
// Making a copy of the message so it can be retried after the first send.
|
||||
Message msgCopy = message;
|
||||
auto fm = send(to, std::move(message));
|
||||
auto fm = toFutureMessage(send(to, std::move(message)));
|
||||
auto firstRetryRpc = std::make_shared<RpcRetryInfo>(
|
||||
to,
|
||||
std::move(msgCopy),
|
||||
@ -133,7 +133,7 @@ void RpcAgent::retryExpiredRpcs() {
|
||||
// with an error, since this RPC never succeeded and can no longer be
|
||||
// retried.
|
||||
try {
|
||||
fm = send(earliestRpc->to_, std::move(msgCopy));
|
||||
fm = toFutureMessage(send(earliestRpc->to_, std::move(msgCopy)));
|
||||
futures.emplace_back(fm, earliestRpc);
|
||||
} catch (std::exception& e) {
|
||||
// We must store the futures and exception messages here and only mark
|
||||
|
@ -157,7 +157,7 @@ class TORCH_API RpcAgent {
|
||||
// If ``message.isRequest()`` is true, the returned ``JitFuture`` will be
|
||||
// completed when the response arrives. For other message types, the Future
|
||||
// should be ignored by the caller.
|
||||
virtual std::shared_ptr<FutureMessage> send(
|
||||
virtual std::shared_ptr<JitFuture> send(
|
||||
const WorkerInfo& to,
|
||||
Message&& message,
|
||||
const float rpcTimeoutSeconds = kUnsetRpcTimeout) = 0;
|
||||
@ -259,6 +259,46 @@ class TORCH_API RpcAgent {
|
||||
// Get the type resolver
|
||||
std::shared_ptr<TypeResolver> getTypeResolver();
|
||||
|
||||
static std::shared_ptr<JitFuture> toJitFuture(
|
||||
std::shared_ptr<FutureMessage>&& fm) {
|
||||
auto jitFuture = std::make_shared<JitFuture>(at::AnyClassType::get());
|
||||
|
||||
std::weak_ptr<FutureMessage> wp = fm;
|
||||
fm->addCallback(
|
||||
[jitFuture, wp]() mutable {
|
||||
auto future = wp.lock();
|
||||
TORCH_INTERNAL_ASSERT(future);
|
||||
if (future->hasError()) {
|
||||
jitFuture->setError(std::make_exception_ptr(*(future->error())));
|
||||
} else {
|
||||
jitFuture->markCompleted(IValue(
|
||||
c10::make_intrusive<Message>(std::move(*future).moveValue())));
|
||||
}
|
||||
}
|
||||
);
|
||||
return jitFuture;
|
||||
}
|
||||
|
||||
static std::shared_ptr<FutureMessage> toFutureMessage(
|
||||
std::shared_ptr<JitFuture>&& jitFuture) {
|
||||
auto fm = std::make_shared<FutureMessage>();
|
||||
|
||||
std::weak_ptr<JitFuture> wp = jitFuture;
|
||||
jitFuture->addCallback(
|
||||
[fm, wp]() mutable {
|
||||
auto future = wp.lock();
|
||||
TORCH_INTERNAL_ASSERT(future);
|
||||
if (future->hasError()) {
|
||||
fm->setError(future->tryRetrieveErrorMessage());
|
||||
} else {
|
||||
fm->markCompleted(
|
||||
std::move(*future->value().toCustomClass<Message>()));
|
||||
}
|
||||
}
|
||||
);
|
||||
return fm;
|
||||
}
|
||||
|
||||
protected:
|
||||
const WorkerInfo workerInfo_;
|
||||
const std::unique_ptr<RequestCallback> cb_;
|
||||
|
@ -604,7 +604,7 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
|
||||
});
|
||||
}
|
||||
|
||||
std::shared_ptr<FutureMessage> TensorPipeAgent::send(
|
||||
std::shared_ptr<JitFuture> TensorPipeAgent::send(
|
||||
const WorkerInfo& toWorkerInfo,
|
||||
Message&& requestMessage,
|
||||
const float rpcTimeoutSeconds) {
|
||||
@ -778,8 +778,8 @@ std::shared_ptr<FutureMessage> TensorPipeAgent::send(
|
||||
});
|
||||
});
|
||||
|
||||
return std::shared_ptr<FutureMessage>(
|
||||
futureResponseMessage, &futureResponseMessage->futMsg);
|
||||
return toJitFuture(std::shared_ptr<FutureMessage>(
|
||||
futureResponseMessage, &futureResponseMessage->futMsg));
|
||||
}
|
||||
|
||||
void TensorPipeAgent::pollTimeoutRpcs() {
|
||||
|
@ -181,7 +181,7 @@ class TensorPipeAgent : public RpcAgent {
|
||||
TensorPipeAgent(const TensorPipeAgent&) = delete;
|
||||
TensorPipeAgent& operator=(const TensorPipeAgent&) = delete;
|
||||
|
||||
std::shared_ptr<FutureMessage> send(
|
||||
std::shared_ptr<JitFuture> send(
|
||||
const WorkerInfo& to,
|
||||
Message&& message,
|
||||
const float rpcTimeoutSeconds = kUnsetRpcTimeout) override;
|
||||
|
@ -56,7 +56,7 @@ std::unordered_map<MessageType, float, std::hash<int>> FaultyProcessGroupAgent::
|
||||
return delayMessages;
|
||||
}
|
||||
|
||||
std::shared_ptr<FutureMessage> FaultyProcessGroupAgent::send(
|
||||
std::shared_ptr<JitFuture> FaultyProcessGroupAgent::send(
|
||||
const WorkerInfo& to,
|
||||
Message&& message,
|
||||
const float rpcTimeoutSeconds) {
|
||||
@ -82,7 +82,7 @@ std::shared_ptr<FutureMessage> FaultyProcessGroupAgent::send(
|
||||
fm->setError(makeRPCError(
|
||||
c10::str("Send attempt failed intentionally for ", key),
|
||||
RPCErrorType::INTENTIONAL_FAILURE));
|
||||
return fm;
|
||||
return toJitFuture(std::move(fm));
|
||||
} else {
|
||||
lock.unlock();
|
||||
return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds);
|
||||
|
@ -43,7 +43,7 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent {
|
||||
int failNumSends = 0);
|
||||
|
||||
// Faulty send function for this class.
|
||||
std::shared_ptr<FutureMessage> send(
|
||||
std::shared_ptr<JitFuture> send(
|
||||
const WorkerInfo& to,
|
||||
Message&& message,
|
||||
const float rpcTimeoutSeconds =
|
||||
|
Reference in New Issue
Block a user