[jit] Make torch::utils::Future and ivalue::future apis closer (#35849)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35849

This change harmonizes some aspects of the api.

- torch::utils::Future callback should have no args, like ivalue::future.
   Many of the lines of this change are related to fixing that up downstream.

   No args makes the api simpler to use, particularly since many/most of the
   downstream use cases ignore the passed-in args. It's simple enough to
   appropriately capture the future in the lambda if necessary.

 - Add error/hasError methods to ivalue::Future.
 - Use c10::optional underneath for error to ivalue::Future.
 - Change markCompleted(error) to setError(error) to ivalue::Future.
 - Add setValue(FutureError) version to torch::utils::Future

ghstack-source-id: 101684435

Test Plan: buck test mode/dev-nosan caffe2/test/...

Differential Revision: D20803251

fbshipit-source-id: e3d925287bd9a80d649843eef5f270163f448269
This commit is contained in:
Jeremy Lilley
2020-04-07 17:03:00 -07:00
committed by Facebook GitHub Bot
parent 373dc7c8ef
commit 72b55fea6b
16 changed files with 220 additions and 267 deletions

View File

@ -13,23 +13,21 @@ thread_local bool RRefContext::recording = false;
namespace callback {
void confirmPendingUser(
const rpc::Message& message,
const c10::optional<utils::FutureError>& futErr,
const std::shared_ptr<FutureMessage>& futureMessage,
const ForkId& expectedForkId) {
if (!futErr) {
auto rr = RemoteRet::fromMessage(message);
if (!futureMessage->hasError()) {
auto rr = RemoteRet::fromMessage(futureMessage->constValue());
TORCH_INTERNAL_ASSERT(rr->forkId() == expectedForkId);
}
RRefContext::getInstance().delPendingUser(expectedForkId);
// Potentially propagate to the userRRef?
RRefContext::handleException(futErr);
RRefContext::handleException(futureMessage);
}
c10::intrusive_ptr<RRef> finishCreatingOwnerRRef(
const Message& message,
const c10::optional<utils::FutureError>& futErr) {
RRefContext::handleException(futErr);
auto rr = RemoteRet::fromMessage(message);
const std::shared_ptr<FutureMessage>& futureMessage) {
RRefContext::handleException(futureMessage);
auto rr = RemoteRet::fromMessage(futureMessage->constValue());
TORCH_INTERNAL_ASSERT(
rr->rrefId() == rr->forkId(),
"Expecting an OwnerRRef as RemoteRet but got a fork.");
@ -70,12 +68,11 @@ std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance(
return deletedRRefs;
}
void RRefContext::handleException(
const c10::optional<utils::FutureError>& futErr) {
if (futErr) {
void RRefContext::handleException(const std::shared_ptr<FutureMessage>& fm) {
if (fm->hasError()) {
// TODO: allow users to register an error handler and call it here.
VLOG(1) << "Got exception: " << (*futErr).what();
throw std::runtime_error((*futErr).what());
VLOG(1) << "Got exception: " << fm->error()->what();
throw std::runtime_error(fm->error()->what());
}
}
@ -181,10 +178,7 @@ void RRefContext::delUser(
agent_->getWorkerInfo(owner),
RRefUserDelete(rrefId, forkId).toMessage());
fm->addCallback([](const Message& /* unused */,
const c10::optional<utils::FutureError>& futErr) {
handleException(futErr);
});
fm->addCallback([fm]() { handleException(fm); });
}
}
@ -393,20 +387,15 @@ void RRefContext::notifyOwnerAndParentOfFork(
// with this fork ID.
auto fm = agent_->sendWithRetries(
agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
fm->addCallback([](const Message& /* unused */,
const c10::optional<utils::FutureError>& futErr) {
handleException(futErr);
});
fm->addCallback([fm]() { handleException(fm); });
} else {
auto fm = agent_->sendWithRetries(
agent_->getWorkerInfo(rref->owner()),
RRefForkRequest(rref->rrefId(), forkId).toMessage());
addPendingUser(forkId, rref);
fm->addCallback([this, forkId, parent](
const Message& /* unused */,
const c10::optional<utils::FutureError>& futErr) {
handleException(futErr);
fm->addCallback([this, forkId, parent, fm]() {
handleException(fm);
this->finishForkRequest(forkId, parent);
});
}
@ -551,15 +540,12 @@ std::shared_ptr<torch::utils::Future<bool>> RRefContext::
auto remainingRRefs =
std::make_shared<std::atomic<uint64_t>>(userTable_.size());
for (auto& state : userTable_) {
state->future_.addCallback(
[future, remainingRRefs](
const bool& /* unused */,
const c10::optional<utils::FutureError>& /* unused */) {
auto localCount = remainingRRefs->fetch_sub(1);
if (localCount == 1) {
future->markCompleted(true);
}
});
state->future_.addCallback([future, remainingRRefs]() {
auto localCount = remainingRRefs->fetch_sub(1);
if (localCount == 1) {
future->markCompleted(true);
}
});
}
userTable_.clear();
}
@ -577,10 +563,7 @@ void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
auto fm = agent_->sendWithRetries(
agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
fm->addCallback([](const Message& /* unused */,
const c10::optional<utils::FutureError>& futErr) {
handleException(futErr);
});
fm->addCallback([fm]() { handleException(fm); });
}
void RRefContext::addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref) {