mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[jit] Make torch::utils::Future and ivalue::future apis closer (#35849)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35849 This change harmonizes some aspects of the api. - torch::utils::Future callback should have no args, like ivalue::future. Many of the lines of this change are related to fixing that up downstream. No args makes the api simpler to use, particularly since many/most of the downstream use cases ignore the passed-in args. It's simple enough to appropriately capture the future in the lambda if necessary. - Add error/hasError methods to ivalue::Future. - Use c10::optional underneath for error to ivalue::Future. - Change markCompleted(error) to setError(error) to ivalue::Future. - Add setValue(FutureError) version to torch::utils::Future ghstack-source-id: 101684435 Test Plan: buck test mode/dev-nosan caffe2/test/... Differential Revision: D20803251 fbshipit-source-id: e3d925287bd9a80d649843eef5f270163f448269
This commit is contained in:
committed by
Facebook GitHub Bot
parent
373dc7c8ef
commit
72b55fea6b
@ -13,23 +13,21 @@ thread_local bool RRefContext::recording = false;
|
||||
|
||||
namespace callback {
|
||||
void confirmPendingUser(
|
||||
const rpc::Message& message,
|
||||
const c10::optional<utils::FutureError>& futErr,
|
||||
const std::shared_ptr<FutureMessage>& futureMessage,
|
||||
const ForkId& expectedForkId) {
|
||||
if (!futErr) {
|
||||
auto rr = RemoteRet::fromMessage(message);
|
||||
if (!futureMessage->hasError()) {
|
||||
auto rr = RemoteRet::fromMessage(futureMessage->constValue());
|
||||
TORCH_INTERNAL_ASSERT(rr->forkId() == expectedForkId);
|
||||
}
|
||||
RRefContext::getInstance().delPendingUser(expectedForkId);
|
||||
// Potentially propagate to the userRRef?
|
||||
RRefContext::handleException(futErr);
|
||||
RRefContext::handleException(futureMessage);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<RRef> finishCreatingOwnerRRef(
|
||||
const Message& message,
|
||||
const c10::optional<utils::FutureError>& futErr) {
|
||||
RRefContext::handleException(futErr);
|
||||
auto rr = RemoteRet::fromMessage(message);
|
||||
const std::shared_ptr<FutureMessage>& futureMessage) {
|
||||
RRefContext::handleException(futureMessage);
|
||||
auto rr = RemoteRet::fromMessage(futureMessage->constValue());
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
rr->rrefId() == rr->forkId(),
|
||||
"Expecting an OwnerRRef as RemoteRet but got a fork.");
|
||||
@ -70,12 +68,11 @@ std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance(
|
||||
return deletedRRefs;
|
||||
}
|
||||
|
||||
void RRefContext::handleException(
|
||||
const c10::optional<utils::FutureError>& futErr) {
|
||||
if (futErr) {
|
||||
void RRefContext::handleException(const std::shared_ptr<FutureMessage>& fm) {
|
||||
if (fm->hasError()) {
|
||||
// TODO: allow users to register an error handler and call it here.
|
||||
VLOG(1) << "Got exception: " << (*futErr).what();
|
||||
throw std::runtime_error((*futErr).what());
|
||||
VLOG(1) << "Got exception: " << fm->error()->what();
|
||||
throw std::runtime_error(fm->error()->what());
|
||||
}
|
||||
}
|
||||
|
||||
@ -181,10 +178,7 @@ void RRefContext::delUser(
|
||||
agent_->getWorkerInfo(owner),
|
||||
RRefUserDelete(rrefId, forkId).toMessage());
|
||||
|
||||
fm->addCallback([](const Message& /* unused */,
|
||||
const c10::optional<utils::FutureError>& futErr) {
|
||||
handleException(futErr);
|
||||
});
|
||||
fm->addCallback([fm]() { handleException(fm); });
|
||||
}
|
||||
}
|
||||
|
||||
@ -393,20 +387,15 @@ void RRefContext::notifyOwnerAndParentOfFork(
|
||||
// with this fork ID.
|
||||
auto fm = agent_->sendWithRetries(
|
||||
agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
|
||||
fm->addCallback([](const Message& /* unused */,
|
||||
const c10::optional<utils::FutureError>& futErr) {
|
||||
handleException(futErr);
|
||||
});
|
||||
fm->addCallback([fm]() { handleException(fm); });
|
||||
} else {
|
||||
auto fm = agent_->sendWithRetries(
|
||||
agent_->getWorkerInfo(rref->owner()),
|
||||
RRefForkRequest(rref->rrefId(), forkId).toMessage());
|
||||
|
||||
addPendingUser(forkId, rref);
|
||||
fm->addCallback([this, forkId, parent](
|
||||
const Message& /* unused */,
|
||||
const c10::optional<utils::FutureError>& futErr) {
|
||||
handleException(futErr);
|
||||
fm->addCallback([this, forkId, parent, fm]() {
|
||||
handleException(fm);
|
||||
this->finishForkRequest(forkId, parent);
|
||||
});
|
||||
}
|
||||
@ -551,15 +540,12 @@ std::shared_ptr<torch::utils::Future<bool>> RRefContext::
|
||||
auto remainingRRefs =
|
||||
std::make_shared<std::atomic<uint64_t>>(userTable_.size());
|
||||
for (auto& state : userTable_) {
|
||||
state->future_.addCallback(
|
||||
[future, remainingRRefs](
|
||||
const bool& /* unused */,
|
||||
const c10::optional<utils::FutureError>& /* unused */) {
|
||||
auto localCount = remainingRRefs->fetch_sub(1);
|
||||
if (localCount == 1) {
|
||||
future->markCompleted(true);
|
||||
}
|
||||
});
|
||||
state->future_.addCallback([future, remainingRRefs]() {
|
||||
auto localCount = remainingRRefs->fetch_sub(1);
|
||||
if (localCount == 1) {
|
||||
future->markCompleted(true);
|
||||
}
|
||||
});
|
||||
}
|
||||
userTable_.clear();
|
||||
}
|
||||
@ -577,10 +563,7 @@ void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
|
||||
auto fm = agent_->sendWithRetries(
|
||||
agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
|
||||
|
||||
fm->addCallback([](const Message& /* unused */,
|
||||
const c10::optional<utils::FutureError>& futErr) {
|
||||
handleException(futErr);
|
||||
});
|
||||
fm->addCallback([fm]() { handleException(fm); });
|
||||
}
|
||||
|
||||
void RRefContext::addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref) {
|
||||
|
Reference in New Issue
Block a user