Replace FutureMessage with ivalue::Future in RRefContext (#49960)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49960

Test Plan: Imported from OSS

Reviewed By: lw

Differential Revision: D25730530

Pulled By: mrshenli

fbshipit-source-id: 5d54572c653592d79c40aed616266c87307a1ad8
This commit is contained in:
Shen Li
2021-01-07 19:43:44 -08:00
committed by Facebook GitHub Bot
parent 25ef605132
commit 008206decc
10 changed files with 72 additions and 78 deletions

View File

@ -14,11 +14,12 @@ thread_local bool RRefContext::recording_ = false;
namespace callback {
void confirmPendingUser(
const FutureMessage& futureMessage,
const JitFuture& jitFuture,
const ForkId& expectedForkId) {
if (!futureMessage.hasError()) {
auto msgType = futureMessage.constValue().type();
auto rpc = deserializeResponse(futureMessage.constValue(), msgType);
if (!jitFuture.hasError()) {
auto msgPtr = jitFuture.constValue().toCustomClass<Message>();
auto msgType = msgPtr->type();
auto rpc = deserializeResponse(*msgPtr, msgType);
auto rr = dynamic_cast<RemoteRet*>(rpc.get());
TORCH_INTERNAL_ASSERT(rr->forkId() == expectedForkId);
} else {
@ -34,30 +35,31 @@ void confirmPendingUser(
// the user application will use the RRef before the errors are handled. In
// this case, errors may not be raised as they have not yet been handled.
auto rref_ptr = RRefContext::getInstance().getPendingUser(expectedForkId);
auto errorType = getRPCErrorType(futureMessage);
rref_ptr->handleError(errorType, futureMessage);
auto errorType = getRPCErrorType(jitFuture);
rref_ptr->handleError(errorType, jitFuture);
}
RRefContext::getInstance().delPendingUser(expectedForkId);
}
c10::intrusive_ptr<RRef> finishCreatingOwnerRRef(
const FutureMessage& futureMessage,
const JitFuture& jitFuture,
const RRefId& rrefId) {
if (futureMessage.hasError()) {
if (jitFuture.hasError()) {
auto& ctx = RRefContext::getInstance();
// We expect to run this callback only after the OwnerRRef has been created,
// since this is only invoked when sending to self.
auto rref_ptr =
ctx.getOwnerRRef(rrefId, /* ensure created */ true)->constValue();
auto errorType = getRPCErrorType(futureMessage);
rref_ptr->handleError(errorType, futureMessage);
auto errorType = getRPCErrorType(jitFuture);
rref_ptr->handleError(errorType, jitFuture);
// OwnerRRefs do not have a forkId, so don't need to assert here.
auto deletedRRef =
ctx.delForkOfOwner(rref_ptr->rrefId(), rref_ptr->rrefId());
return deletedRRef;
} else {
auto msgType = futureMessage.constValue().type();
auto rpc = deserializeResponse(futureMessage.constValue(), msgType);
auto msgPtr = jitFuture.constValue().toCustomClass<Message>();
auto msgType = msgPtr->type();
auto rpc = deserializeResponse(*msgPtr, msgType);
auto rr = dynamic_cast<RemoteRet*>(rpc.get());
TORCH_INTERNAL_ASSERT(
rr->rrefId() == rr->forkId(),
@ -102,6 +104,14 @@ std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance(
return deletedRRefs;
}
// Surfaces an error stored in `jitFuture` to the caller.
//
// When the future reports no error this is a no-op. Otherwise the error
// message is retrieved from the future, logged at verbose level 1, and
// rethrown as a std::runtime_error carrying that same message.
void RRefContext::handleException(const JitFuture& jitFuture) {
  // Nothing to report for a future without an error.
  if (!jitFuture.hasError()) {
    return;
  }

  const auto message = jitFuture.tryRetrieveErrorMessage();
  VLOG(1) << "Got exception: " << message;
  throw std::runtime_error(message);
}
void RRefContext::handleException(const FutureMessage& fm) {
if (fm.hasError()) {
VLOG(1) << "Got exception: " << fm.error()->what();