[future] Avoid some future callback self-captures. (#36502)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36502

We're sometimes deleting futures without completing them (discovered by logging),
and we've recently noticed a slow memory leak.

This change migrates the future lambda cases where there was self-capture.
 - In some cases, we use weak_ptr<>, plus .lock()/assert in the lambda callback.
   This avoids the reference cycle. We use this primarily in the case where the
   value ends up being moved in the callback (something we want to be careful about)

 - We also add a convenience api to Future where the completed Future is returned as an arg.
   This allows us to avoid self-capture, though it assumes that the markCompleted()
   caller is persisting the future for the markCompleted() duration (this has been the case)

ghstack-source-id: 102130672

Test Plan: ctr_mobile_feed, buck test mode/dev-nosan caffe2/test/...

Differential Revision: D20998905

fbshipit-source-id: 7dd52fe4e567a5dea20e8d43862fc2335fd3ce16
This commit is contained in:
Jeremy Lilley
2020-04-14 17:49:40 -07:00
committed by Facebook GitHub Bot
parent 1a0b95e7e4
commit 37aab14d14
11 changed files with 210 additions and 195 deletions

View File

@ -13,10 +13,10 @@ thread_local bool RRefContext::recording_ = false;
namespace callback {
void confirmPendingUser(
const std::shared_ptr<FutureMessage>& futureMessage,
const FutureMessage& futureMessage,
const ForkId& expectedForkId) {
if (!futureMessage->hasError()) {
auto rr = RemoteRet::fromMessage(futureMessage->constValue());
if (!futureMessage.hasError()) {
auto rr = RemoteRet::fromMessage(futureMessage.constValue());
TORCH_INTERNAL_ASSERT(rr->forkId() == expectedForkId);
}
RRefContext::getInstance().delPendingUser(expectedForkId);
@ -25,9 +25,9 @@ void confirmPendingUser(
}
c10::intrusive_ptr<RRef> finishCreatingOwnerRRef(
const std::shared_ptr<FutureMessage>& futureMessage) {
const FutureMessage& futureMessage) {
RRefContext::handleException(futureMessage);
auto rr = RemoteRet::fromMessage(futureMessage->constValue());
auto rr = RemoteRet::fromMessage(futureMessage.constValue());
TORCH_INTERNAL_ASSERT(
rr->rrefId() == rr->forkId(),
"Expecting an OwnerRRef as RemoteRet but got a fork.");
@ -68,11 +68,11 @@ std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance(
return deletedRRefs;
}
void RRefContext::handleException(const std::shared_ptr<FutureMessage>& fm) {
if (fm->hasError()) {
void RRefContext::handleException(const FutureMessage& fm) {
if (fm.hasError()) {
// TODO: allow users to register an error handler and call it here.
VLOG(1) << "Got exception: " << fm->error()->what();
throw std::runtime_error(fm->error()->what());
VLOG(1) << "Got exception: " << fm.error()->what();
throw std::runtime_error(fm.error()->what());
}
}
@ -178,7 +178,7 @@ void RRefContext::delUser(
agent_->getWorkerInfo(owner),
RRefUserDelete(rrefId, forkId).toMessage());
fm->addCallback([fm]() { handleException(fm); });
fm->addCallback([](const FutureMessage& fm) { handleException(fm); });
}
}
@ -387,14 +387,14 @@ void RRefContext::notifyOwnerAndParentOfFork(
// with this fork ID.
auto fm = agent_->sendWithRetries(
agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
fm->addCallback([fm]() { handleException(fm); });
fm->addCallback([](const FutureMessage& fm) { handleException(fm); });
} else {
auto fm = agent_->sendWithRetries(
agent_->getWorkerInfo(rref->owner()),
RRefForkRequest(rref->rrefId(), forkId).toMessage());
addPendingUser(forkId, rref);
fm->addCallback([this, forkId, parent, fm]() {
fm->addCallback([this, forkId, parent](const FutureMessage& fm) {
handleException(fm);
this->finishForkRequest(forkId, parent);
});
@ -564,7 +564,7 @@ void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
auto fm = agent_->sendWithRetries(
agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
fm->addCallback([fm]() { handleException(fm); });
fm->addCallback([](const FutureMessage& fm) { handleException(fm); });
}
void RRefContext::addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref) {