Simplify process(Script|Python)(Remote)?Call (#57857)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57857

There used to be a whole lot of methods: `processPythonCall`, `processScriptCall`, `processScriptRemoteCall`, `processPythonRemoteCall`, `processScriptCallOp`, `processBaseScriptRemoteCall` and `processScriptRemoteCallOp`. Thanks to the previous simplification, we can now drop all but the first four, which map nicely 1:1 to the four message types we need to handle (`PYTHON_CALL`, `SCRIPT_CALL`, `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL`). Their signatures also become much simpler: they take an RPC command (plus the stream context, in the case of `processPythonRemoteCall`) and return a future.
ghstack-source-id: 129567070

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28253848

fbshipit-source-id: e0e45345c414a96900f9d70ee555359d28908833
This commit is contained in:
Luca Wehrstedt
2021-05-21 13:10:24 -07:00
committed by Facebook GitHub Bot
parent c96a05d148
commit cd9dbbd93a
4 changed files with 92 additions and 152 deletions

View File

@ -137,74 +137,75 @@ std::unique_ptr<RpcCommandBase> RequestCallbackImpl::
return pythonRpc ? std::move(pythonRpc) : std::move(rpc);
}
void RequestCallbackImpl::processScriptCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& responseFuture) const {
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptCall(
RpcCommandBase& rpc) const {
auto& scriptCall = static_cast<ScriptCall&>(rpc);
auto& stack = scriptCall.stackRef();
c10::intrusive_ptr<JitFuture> future;
if (scriptCall.hasOp()) {
processScriptCallOp(scriptCall, markComplete, stack);
return;
future = runJitOperator(*scriptCall.op(), scriptCall.stackRef());
} else {
future = runJitFunction(
scriptCall.qualifiedName(),
scriptCall.stackRef(),
scriptCall.isAsyncExecution());
}
auto jitFuture = runJitFunction(
scriptCall.qualifiedName(), stack, scriptCall.isAsyncExecution());
jitFuture->addCallback([responseFuture, markComplete](JitFuture& jitFuture) {
if (jitFuture.hasError()) {
responseFuture->setError(jitFuture.exception_ptr());
} else {
responseFuture->markCompleted(c10::make_intrusive<Message>(
ScriptResp(jitFuture.value()).toMessage()));
}
});
return future->then(
[](JitFuture& jitFuture) {
return c10::make_intrusive<Message>(
ScriptResp(jitFuture.value()).toMessage());
},
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
void RequestCallbackImpl::processPythonCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& responseFuture) const {
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonCall(
RpcCommandBase& rpc) const {
auto& upc = static_cast<UnpickledPythonCall&>(rpc);
auto future = runPythonFunction(upc.pythonUdf(), upc.isAsyncExecution());
future->addCallback([responseFuture](JitFuture& future) {
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
py::gil_scoped_acquire acquire;
auto serializedPyObj =
pythonRpcHandler.serialize(jit::toPyObject(future.value()));
responseFuture->markCompleted(c10::make_intrusive<Message>(
PythonResp(std::move(serializedPyObj)).toMessage()));
});
return future->then(
[](JitFuture& future) {
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
py::gil_scoped_acquire acquire;
auto serializedPyObj = pythonRpcHandler.serialize(
jit::toPyObject(future.value()));
return c10::make_intrusive<Message>(
PythonResp(std::move(serializedPyObj)).toMessage());
},
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptRemoteCall(
ScriptRemoteCall& scriptRemoteCall,
std::vector<at::IValue>& stack) const {
RpcCommandBase& rpc) const {
auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
c10::intrusive_ptr<JitFuture> future;
if (scriptRemoteCall.hasOp()) {
return processScriptRemoteCallOp(scriptRemoteCall, stack);
future =
runJitOperator(*scriptRemoteCall.op(), scriptRemoteCall.stackRef());
} else {
future = runJitFunction(
scriptRemoteCall.qualifiedName(),
scriptRemoteCall.stackRef(),
scriptRemoteCall.isAsyncExecution());
}
return runJitFunction(
scriptRemoteCall.qualifiedName(),
stack,
scriptRemoteCall.isAsyncExecution());
return assignOwnerRRef(
scriptRemoteCall.retRRefId(),
scriptRemoteCall.retForkId(),
std::move(future),
/*lsctx=*/nullptr);
}
void RequestCallbackImpl::processPythonRemoteCall(
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRemoteCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> lsctx) const {
auto& uprc = static_cast<UnpickledPythonRemoteCall&>(rpc);
auto future = runPythonFunction(uprc.pythonUdf(), uprc.isAsyncExecution());
assignOwnerRRef(
uprc.rrefId(),
uprc.forkId(),
std::move(future),
responseFuture,
std::move(lsctx));
return assignOwnerRRef(
uprc.rrefId(), uprc.forkId(), std::move(future), std::move(lsctx));
}
void RequestCallbackImpl::processPythonRRefFetchCall(

View File

@ -15,24 +15,17 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython {
std::unique_ptr<RpcCommandBase> rpc,
const MessageType& messageType) const override;
void processPythonCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
c10::intrusive_ptr<JitFuture> processPythonCall(
RpcCommandBase& rpc) const override;
void processScriptCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
c10::intrusive_ptr<JitFuture> processScriptCall(
RpcCommandBase& rpc) const override;
c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
ScriptRemoteCall& scriptRemoteCall,
std::vector<at::IValue>& stack) const override;
RpcCommandBase& rpc) const override;
void processPythonRemoteCall(
c10::intrusive_ptr<JitFuture> processPythonRemoteCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const override;
void processPythonRRefFetchCall(

View File

@ -136,48 +136,37 @@ c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpcWithErrors(
}
}
void RequestCallbackNoPython::processScriptCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& /* unused */) const {
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptCall(
RpcCommandBase& rpc) const {
auto& scriptCall = static_cast<ScriptCall&>(rpc);
auto& stack = scriptCall.stackRef();
TORCH_CHECK(
scriptCall.hasOp(), "Only supports the case where ScriptCall has an op");
processScriptCallOp(scriptCall, markComplete, stack);
auto future = runJitOperator(*scriptCall.op(), scriptCall.stackRef());
return future->then(
[](JitFuture& future) {
return c10::make_intrusive<Message>(
ScriptResp(future.value()).toMessage());
},
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
void RequestCallbackNoPython::processScriptCallOp(
ScriptCall& scriptCall,
const std::function<void(Message)>& markComplete,
std::vector<at::IValue>& stack) const {
TORCH_INTERNAL_ASSERT(scriptCall.hasOp());
auto future = runJitOperator(*scriptCall.op(), stack);
future->addCallback([markComplete](JitFuture& future) {
markComplete(ScriptResp(future.value()).toMessage());
});
}
void RequestCallbackNoPython::processPythonCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& /* unused */) const {
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonCall(
RpcCommandBase& rpc) const {
C10_THROW_ERROR(Error, "Python call not supported!");
}
void RequestCallbackNoPython::processPythonRemoteCall(
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonRemoteCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& /* unused */,
std::shared_ptr<LazyStreamContext> /* unused */) const {
C10_THROW_ERROR(Error, "Python call not supported!");
}
void RequestCallbackNoPython::assignOwnerRRef(
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::assignOwnerRRef(
const RRefId& rrefId,
const RRefId& forkId,
c10::intrusive_ptr<JitFuture> valueFuture,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> lsctx) const {
auto& ctx = RRefContext::getInstance();
@ -201,52 +190,36 @@ void RequestCallbackNoPython::assignOwnerRRef(
ctx.addForkOfOwner(rrefId, forkId);
}
valueFuture->addCallback(
[ownerRRef, rrefId, forkId, responseFuture, lsctx = std::move(lsctx)](
JitFuture& future) {
return valueFuture->then(
[ownerRRef, rrefId, forkId, lsctx = std::move(lsctx)](JitFuture& future) {
if (future.hasError()) {
ownerRRef->setError(future.exception_ptr());
} else {
ownerRRef->recordAllStreams(lsctx);
ownerRRef->setValue(future.value());
}
responseFuture->markCompleted(c10::make_intrusive<Message>(
RemoteRet(rrefId, forkId).toMessage()));
});
return c10::make_intrusive<Message>(
RemoteRet(rrefId, forkId).toMessage());
},
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptRemoteCall(
ScriptRemoteCall& scriptRemoteCall,
std::vector<at::IValue>& stack) const {
RpcCommandBase& rpc) const {
auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
TORCH_CHECK(
scriptRemoteCall.hasOp(), "ScriptRemoteCall needs to have an op!");
return processScriptRemoteCallOp(scriptRemoteCall, stack);
}
auto future =
runJitOperator(*scriptRemoteCall.op(), scriptRemoteCall.stackRef());
void RequestCallbackNoPython::processBaseScriptRemoteCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& responseFuture) const {
auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
auto& stack = scriptRemoteCall.stackRef();
auto jitFuture = processScriptRemoteCall(scriptRemoteCall, stack);
assignOwnerRRef(
return assignOwnerRRef(
scriptRemoteCall.retRRefId(),
scriptRemoteCall.retForkId(),
std::move(jitFuture),
responseFuture,
std::move(future),
/*lsctx=*/nullptr);
}
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
processScriptRemoteCallOp(
ScriptRemoteCall& scriptRemoteCall,
std::vector<at::IValue>& stack) const {
TORCH_INTERNAL_ASSERT(scriptRemoteCall.hasOp());
return runJitOperator(*scriptRemoteCall.op(), stack);
}
void RequestCallbackNoPython::processScriptRRefFetchCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
@ -556,21 +529,16 @@ c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpc(
// to a python object.
switch (messageType) {
case MessageType::SCRIPT_CALL: {
processScriptCall(rpc, markComplete, responseFuture);
return responseFuture;
return processScriptCall(rpc);
}
case MessageType::PYTHON_CALL: {
processPythonCall(rpc, markComplete, responseFuture);
return responseFuture;
return processPythonCall(rpc);
}
case MessageType::SCRIPT_REMOTE_CALL: {
processBaseScriptRemoteCall(rpc, markComplete, responseFuture);
return responseFuture;
return processScriptRemoteCall(rpc);
}
case MessageType::PYTHON_REMOTE_CALL: {
processPythonRemoteCall(
rpc, markComplete, responseFuture, std::move(ctx));
return responseFuture;
return processPythonRemoteCall(rpc, std::move(ctx));
}
case MessageType::SCRIPT_RREF_FETCH_CALL: {
processScriptRRefFetchCall(rpc, markComplete, responseFuture);

View File

@ -23,45 +23,23 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback {
std::unique_ptr<RpcCommandBase> rpc,
const MessageType& messageType) const;
virtual void processScriptCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& responseFuture) const;
virtual c10::intrusive_ptr<JitFuture> processScriptCall(
RpcCommandBase& rpc) const;
void processScriptCallOp(
ScriptCall& scriptCall,
const std::function<void(Message)>& markComplete,
std::vector<at::IValue>& stack) const;
virtual c10::intrusive_ptr<JitFuture> processPythonCall(
RpcCommandBase& rpc) const;
virtual void processPythonCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& responseFuture) const;
void assignOwnerRRef(
c10::intrusive_ptr<JitFuture> assignOwnerRRef(
const RRefId& rrefId,
const RRefId& forkId,
c10::intrusive_ptr<JitFuture> valueFuture,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> lsctx) const;
virtual c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
ScriptRemoteCall& scriptRemoteCall,
std::vector<at::IValue>& stack) const;
RpcCommandBase& rpc) const;
void processBaseScriptRemoteCall(
virtual c10::intrusive_ptr<JitFuture> processPythonRemoteCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& responseFuture) const;
c10::intrusive_ptr<JitFuture> processScriptRemoteCallOp(
ScriptRemoteCall& scriptRemoteCall,
std::vector<at::IValue>& stack) const;
virtual void processPythonRemoteCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const;
void processScriptRRefFetchCall(