mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Replace FutureMessage with ivalue::Future in RRefContext (#49960)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49960 Test Plan: Imported from OSS Reviewed By: lw Differential Revision: D25730530 Pulled By: mrshenli fbshipit-source-id: 5d54572c653592d79c40aed616266c87307a1ad8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
25ef605132
commit
008206decc
@ -14,11 +14,12 @@ thread_local bool RRefContext::recording_ = false;
|
||||
|
||||
namespace callback {
|
||||
void confirmPendingUser(
|
||||
const FutureMessage& futureMessage,
|
||||
const JitFuture& jitFuture,
|
||||
const ForkId& expectedForkId) {
|
||||
if (!futureMessage.hasError()) {
|
||||
auto msgType = futureMessage.constValue().type();
|
||||
auto rpc = deserializeResponse(futureMessage.constValue(), msgType);
|
||||
if (!jitFuture.hasError()) {
|
||||
auto msgPtr = jitFuture.constValue().toCustomClass<Message>();
|
||||
auto msgType = msgPtr->type();
|
||||
auto rpc = deserializeResponse(*msgPtr, msgType);
|
||||
auto rr = dynamic_cast<RemoteRet*>(rpc.get());
|
||||
TORCH_INTERNAL_ASSERT(rr->forkId() == expectedForkId);
|
||||
} else {
|
||||
@ -34,30 +35,31 @@ void confirmPendingUser(
|
||||
// the user application will use the RRef before the errors are handled. In
|
||||
// this case, errors may not be raised as they have not yet been handled.
|
||||
auto rref_ptr = RRefContext::getInstance().getPendingUser(expectedForkId);
|
||||
auto errorType = getRPCErrorType(futureMessage);
|
||||
rref_ptr->handleError(errorType, futureMessage);
|
||||
auto errorType = getRPCErrorType(jitFuture);
|
||||
rref_ptr->handleError(errorType, jitFuture);
|
||||
}
|
||||
RRefContext::getInstance().delPendingUser(expectedForkId);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<RRef> finishCreatingOwnerRRef(
|
||||
const FutureMessage& futureMessage,
|
||||
const JitFuture& jitFuture,
|
||||
const RRefId& rrefId) {
|
||||
if (futureMessage.hasError()) {
|
||||
if (jitFuture.hasError()) {
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
// We expect to run this callback only after the OwnerRRef has been created,
|
||||
// since this is only invoked when sending to self.
|
||||
auto rref_ptr =
|
||||
ctx.getOwnerRRef(rrefId, /* ensure created */ true)->constValue();
|
||||
auto errorType = getRPCErrorType(futureMessage);
|
||||
rref_ptr->handleError(errorType, futureMessage);
|
||||
auto errorType = getRPCErrorType(jitFuture);
|
||||
rref_ptr->handleError(errorType, jitFuture);
|
||||
// OwnerRRefs do not have a forkId, so don't need to assert here.
|
||||
auto deletedRRef =
|
||||
ctx.delForkOfOwner(rref_ptr->rrefId(), rref_ptr->rrefId());
|
||||
return deletedRRef;
|
||||
} else {
|
||||
auto msgType = futureMessage.constValue().type();
|
||||
auto rpc = deserializeResponse(futureMessage.constValue(), msgType);
|
||||
auto msgPtr = jitFuture.constValue().toCustomClass<Message>();
|
||||
auto msgType = msgPtr->type();
|
||||
auto rpc = deserializeResponse(*msgPtr, msgType);
|
||||
auto rr = dynamic_cast<RemoteRet*>(rpc.get());
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
rr->rrefId() == rr->forkId(),
|
||||
@ -102,6 +104,14 @@ std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance(
|
||||
return deletedRRefs;
|
||||
}
|
||||
|
||||
void RRefContext::handleException(const JitFuture& jitFuture) {
|
||||
if (jitFuture.hasError()) {
|
||||
auto errMsg = jitFuture.tryRetrieveErrorMessage();
|
||||
VLOG(1) << "Got exception: " << errMsg;
|
||||
throw std::runtime_error(errMsg);
|
||||
}
|
||||
}
|
||||
|
||||
void RRefContext::handleException(const FutureMessage& fm) {
|
||||
if (fm.hasError()) {
|
||||
VLOG(1) << "Got exception: " << fm.error()->what();
|
||||
|
Reference in New Issue
Block a user