Always use intrusive_ptr for Message (1 out of 2) (#58422)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58422

Similar to Future (which I tackled recently), Message is an ivalue type (a "custom class" one), and the natural way to represent it is inside an intrusive_ptr. However, in the RPC code we had a mix of usages, often passing Message by value. This is undesirable because it can easily trigger an accidental copy, which I believe is why many places accepted _rvalue references_ to Message: to force the caller to move. In my experience this is non-idiomatic C++ (normally a function signature specifies how the function consumes its arguments, and it is up to the caller to decide whether to copy or move).
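
To illustrate the accidental-copy concern, here is a minimal sketch. It is not code from this PR: `Msg`, `copyCount`, `sendByValue`, and `sendByPtr` are hypothetical names, and `std::shared_ptr` is used only as a stand-in for `c10::intrusive_ptr` so the example compiles without PyTorch.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>

// Hypothetical counter of deep copies, for illustration only.
int& copyCount() {
  static int n = 0;
  return n;
}

// Hypothetical stand-in for the RPC Message type.
struct Msg {
  std::string payload;
  explicit Msg(std::string p) : payload(std::move(p)) {}
  Msg(const Msg& other) : payload(other.payload) { ++copyCount(); }
  Msg(Msg&&) = default;
};

// By-value signature: a caller that forgets std::move silently deep-copies.
std::size_t sendByValue(Msg m) {
  return m.payload.size();
}

// Pointer signature (std::shared_ptr standing in for c10::intrusive_ptr):
// copying the argument only bumps a reference count.
std::size_t sendByPtr(std::shared_ptr<Msg> m) {
  return m->payload.size();
}
```

With the pointer signature, a caller can keep its own handle and still hand the message to the callee without duplicating the payload.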

By moving to intrusive_ptr everywhere, I think we eliminate, or at least simplify, many of the problems above.
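
As a rough sketch of how this plays out in the retry path, the example below keeps a second handle to the message for a later retry instead of copying the message itself. Again, `std::shared_ptr` stands in for `c10::intrusive_ptr`, and `Payload`, `RetryInfo`, `outbox`, and this simplified `send`/`sendWithRetries` pair are hypothetical, not the real RpcAgent API.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for Message.
struct Payload {
  std::string body;
};

// Hypothetical stand-in for RpcRetryInfo: holds a handle, not a copy.
struct RetryInfo {
  std::shared_ptr<Payload> message;
  int retryCount;
};

// Hypothetical transport: records every message it was asked to send.
std::vector<std::shared_ptr<Payload>>& outbox() {
  static std::vector<std::shared_ptr<Payload>> sent;
  return sent;
}

void send(std::shared_ptr<Payload> message) {
  outbox().push_back(std::move(message));
}

RetryInfo sendWithRetries(std::shared_ptr<Payload> message) {
  // Build the retry record first, while `message` is still valid;
  // this copies the pointer only, never the payload.
  RetryInfo info{message, /* retryCount */ 0};
  // Only then give up our own handle to the transport.
  send(std::move(message));
  return info;
}
```

A later retry can call `send(info.message)` again with the same underlying object, which is why no explicit "copy for retry" step is needed in this sketch.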

In this PR I do half of the migration, by updating everything except the `toMessageImpl` methods, which will come in the next PR.
ghstack-source-id: 129567053

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28474878

fbshipit-source-id: 5b76d45e05f6fa58c831e369c5c964d126187a6c
This commit is contained in:
Luca Wehrstedt
2021-05-21 13:10:24 -07:00
committed by Facebook GitHub Bot
parent 35ea8779da
commit 4d704e607d
24 changed files with 159 additions and 176 deletions

@@ -50,7 +50,7 @@ void RpcAgent::shutdown() {
 c10::intrusive_ptr<JitFuture> RpcAgent::sendWithRetries(
 const WorkerInfo& to,
-Message&& message,
+c10::intrusive_ptr<Message> message,
 RpcRetryOptions retryOptions) {
 TORCH_CHECK(retryOptions.maxRetries >= 0, "maxRetries cannot be negative.");
 TORCH_CHECK(
@@ -64,15 +64,13 @@ c10::intrusive_ptr<JitFuture> RpcAgent::sendWithRetries(
 c10::make_intrusive<JitFuture>(at::AnyClassType::get(), getDevices());
 steady_clock_time_point newTime =
 computeNewRpcRetryTime(retryOptions, /* retryCount */ 0);
-// Making a copy of the message so it can be retried after the first send.
-Message msgCopy = message;
-auto jitFuture = send(to, std::move(message));
 auto firstRetryRpc = std::make_shared<RpcRetryInfo>(
 to,
-std::move(msgCopy),
+message,
 originalFuture,
 /* retryCount */ 0,
 retryOptions);
+auto jitFuture = send(to, std::move(message));
 jitFuture->addCallback([this, newTime, firstRetryRpc](JitFuture& future) {
 rpcRetryCallback(future, newTime, firstRetryRpc);
 });
@@ -122,8 +120,6 @@ void RpcAgent::retryExpiredRpcs() {
 for (auto it = earliestRpcList.begin(); it != earliestRpcList.end();
 /* no increment */) {
 auto& earliestRpc = *it;
-// Making a copy of the message so it can be retried in the future.
-Message msgCopy = earliestRpc->message_;
 c10::intrusive_ptr<JitFuture> jitFuture;
 // send() will throw an exception if an RPC is retried while the agent is
@@ -131,7 +127,7 @@ void RpcAgent::retryExpiredRpcs() {
 // with an error, since this RPC never succeeded and can no longer be
 // retried.
 try {
-jitFuture = send(earliestRpc->to_, std::move(msgCopy));
+jitFuture = send(earliestRpc->to_, earliestRpc->message_);
 futures.emplace_back(jitFuture, earliestRpc);
 } catch (std::exception& e) {
 // We must store the futures and exception messages here and only mark