mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Simplify process(Script|Python)(Remote)?Call (#57857)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57857

There used to be a whole lot of methods: `processPythonCall`, `processScriptCall`, `processScriptRemoteCall`, `processPythonRemoteCall`, `processScriptCallOp`, `processBaseScriptRemoteCall` and `processScriptRemoteCallOp`. Thanks to the previous simplification, we can now drop all but the first four, which map nicely 1:1 to the four message types we need to handle. Also their signatures become much simpler: they take an RPC command and return a future.

ghstack-source-id: 129567070

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28253848

fbshipit-source-id: e0e45345c414a96900f9d70ee555359d28908833
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c96a05d148
commit
cd9dbbd93a
@ -137,74 +137,75 @@ std::unique_ptr<RpcCommandBase> RequestCallbackImpl::
|
||||
return pythonRpc ? std::move(pythonRpc) : std::move(rpc);
|
||||
}
|
||||
|
||||
void RequestCallbackImpl::processScriptCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture) const {
|
||||
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptCall(
|
||||
RpcCommandBase& rpc) const {
|
||||
auto& scriptCall = static_cast<ScriptCall&>(rpc);
|
||||
auto& stack = scriptCall.stackRef();
|
||||
|
||||
c10::intrusive_ptr<JitFuture> future;
|
||||
if (scriptCall.hasOp()) {
|
||||
processScriptCallOp(scriptCall, markComplete, stack);
|
||||
return;
|
||||
future = runJitOperator(*scriptCall.op(), scriptCall.stackRef());
|
||||
} else {
|
||||
future = runJitFunction(
|
||||
scriptCall.qualifiedName(),
|
||||
scriptCall.stackRef(),
|
||||
scriptCall.isAsyncExecution());
|
||||
}
|
||||
|
||||
auto jitFuture = runJitFunction(
|
||||
scriptCall.qualifiedName(), stack, scriptCall.isAsyncExecution());
|
||||
|
||||
jitFuture->addCallback([responseFuture, markComplete](JitFuture& jitFuture) {
|
||||
if (jitFuture.hasError()) {
|
||||
responseFuture->setError(jitFuture.exception_ptr());
|
||||
} else {
|
||||
responseFuture->markCompleted(c10::make_intrusive<Message>(
|
||||
ScriptResp(jitFuture.value()).toMessage()));
|
||||
}
|
||||
});
|
||||
return future->then(
|
||||
[](JitFuture& jitFuture) {
|
||||
return c10::make_intrusive<Message>(
|
||||
ScriptResp(jitFuture.value()).toMessage());
|
||||
},
|
||||
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
|
||||
}
|
||||
|
||||
void RequestCallbackImpl::processPythonCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture) const {
|
||||
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonCall(
|
||||
RpcCommandBase& rpc) const {
|
||||
auto& upc = static_cast<UnpickledPythonCall&>(rpc);
|
||||
auto future = runPythonFunction(upc.pythonUdf(), upc.isAsyncExecution());
|
||||
|
||||
future->addCallback([responseFuture](JitFuture& future) {
|
||||
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto serializedPyObj =
|
||||
pythonRpcHandler.serialize(jit::toPyObject(future.value()));
|
||||
responseFuture->markCompleted(c10::make_intrusive<Message>(
|
||||
PythonResp(std::move(serializedPyObj)).toMessage()));
|
||||
});
|
||||
return future->then(
|
||||
[](JitFuture& future) {
|
||||
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto serializedPyObj = pythonRpcHandler.serialize(
|
||||
jit::toPyObject(future.value()));
|
||||
return c10::make_intrusive<Message>(
|
||||
PythonResp(std::move(serializedPyObj)).toMessage());
|
||||
},
|
||||
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptRemoteCall(
|
||||
ScriptRemoteCall& scriptRemoteCall,
|
||||
std::vector<at::IValue>& stack) const {
|
||||
RpcCommandBase& rpc) const {
|
||||
auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
|
||||
|
||||
c10::intrusive_ptr<JitFuture> future;
|
||||
if (scriptRemoteCall.hasOp()) {
|
||||
return processScriptRemoteCallOp(scriptRemoteCall, stack);
|
||||
future =
|
||||
runJitOperator(*scriptRemoteCall.op(), scriptRemoteCall.stackRef());
|
||||
} else {
|
||||
future = runJitFunction(
|
||||
scriptRemoteCall.qualifiedName(),
|
||||
scriptRemoteCall.stackRef(),
|
||||
scriptRemoteCall.isAsyncExecution());
|
||||
}
|
||||
|
||||
return runJitFunction(
|
||||
scriptRemoteCall.qualifiedName(),
|
||||
stack,
|
||||
scriptRemoteCall.isAsyncExecution());
|
||||
return assignOwnerRRef(
|
||||
scriptRemoteCall.retRRefId(),
|
||||
scriptRemoteCall.retForkId(),
|
||||
std::move(future),
|
||||
/*lsctx=*/nullptr);
|
||||
}
|
||||
|
||||
void RequestCallbackImpl::processPythonRemoteCall(
|
||||
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRemoteCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture,
|
||||
std::shared_ptr<LazyStreamContext> lsctx) const {
|
||||
auto& uprc = static_cast<UnpickledPythonRemoteCall&>(rpc);
|
||||
auto future = runPythonFunction(uprc.pythonUdf(), uprc.isAsyncExecution());
|
||||
|
||||
assignOwnerRRef(
|
||||
uprc.rrefId(),
|
||||
uprc.forkId(),
|
||||
std::move(future),
|
||||
responseFuture,
|
||||
std::move(lsctx));
|
||||
return assignOwnerRRef(
|
||||
uprc.rrefId(), uprc.forkId(), std::move(future), std::move(lsctx));
|
||||
}
|
||||
|
||||
void RequestCallbackImpl::processPythonRRefFetchCall(
|
||||
|
@ -15,24 +15,17 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython {
|
||||
std::unique_ptr<RpcCommandBase> rpc,
|
||||
const MessageType& messageType) const override;
|
||||
|
||||
void processPythonCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
|
||||
c10::intrusive_ptr<JitFuture> processPythonCall(
|
||||
RpcCommandBase& rpc) const override;
|
||||
|
||||
void processScriptCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
|
||||
c10::intrusive_ptr<JitFuture> processScriptCall(
|
||||
RpcCommandBase& rpc) const override;
|
||||
|
||||
c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
|
||||
ScriptRemoteCall& scriptRemoteCall,
|
||||
std::vector<at::IValue>& stack) const override;
|
||||
RpcCommandBase& rpc) const override;
|
||||
|
||||
void processPythonRemoteCall(
|
||||
c10::intrusive_ptr<JitFuture> processPythonRemoteCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture,
|
||||
std::shared_ptr<LazyStreamContext> ctx) const override;
|
||||
|
||||
void processPythonRRefFetchCall(
|
||||
|
@ -136,48 +136,37 @@ c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpcWithErrors(
|
||||
}
|
||||
}
|
||||
|
||||
void RequestCallbackNoPython::processScriptCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& /* unused */) const {
|
||||
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptCall(
|
||||
RpcCommandBase& rpc) const {
|
||||
auto& scriptCall = static_cast<ScriptCall&>(rpc);
|
||||
auto& stack = scriptCall.stackRef();
|
||||
|
||||
TORCH_CHECK(
|
||||
scriptCall.hasOp(), "Only supports the case where ScriptCall has an op");
|
||||
processScriptCallOp(scriptCall, markComplete, stack);
|
||||
auto future = runJitOperator(*scriptCall.op(), scriptCall.stackRef());
|
||||
|
||||
return future->then(
|
||||
[](JitFuture& future) {
|
||||
return c10::make_intrusive<Message>(
|
||||
ScriptResp(future.value()).toMessage());
|
||||
},
|
||||
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
|
||||
}
|
||||
|
||||
void RequestCallbackNoPython::processScriptCallOp(
|
||||
ScriptCall& scriptCall,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
std::vector<at::IValue>& stack) const {
|
||||
TORCH_INTERNAL_ASSERT(scriptCall.hasOp());
|
||||
auto future = runJitOperator(*scriptCall.op(), stack);
|
||||
future->addCallback([markComplete](JitFuture& future) {
|
||||
markComplete(ScriptResp(future.value()).toMessage());
|
||||
});
|
||||
}
|
||||
|
||||
void RequestCallbackNoPython::processPythonCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& /* unused */) const {
|
||||
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonCall(
|
||||
RpcCommandBase& rpc) const {
|
||||
C10_THROW_ERROR(Error, "Python call not supported!");
|
||||
}
|
||||
|
||||
void RequestCallbackNoPython::processPythonRemoteCall(
|
||||
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonRemoteCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& /* unused */,
|
||||
std::shared_ptr<LazyStreamContext> /* unused */) const {
|
||||
C10_THROW_ERROR(Error, "Python call not supported!");
|
||||
}
|
||||
|
||||
void RequestCallbackNoPython::assignOwnerRRef(
|
||||
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::assignOwnerRRef(
|
||||
const RRefId& rrefId,
|
||||
const RRefId& forkId,
|
||||
c10::intrusive_ptr<JitFuture> valueFuture,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture,
|
||||
std::shared_ptr<LazyStreamContext> lsctx) const {
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
|
||||
@ -201,52 +190,36 @@ void RequestCallbackNoPython::assignOwnerRRef(
|
||||
ctx.addForkOfOwner(rrefId, forkId);
|
||||
}
|
||||
|
||||
valueFuture->addCallback(
|
||||
[ownerRRef, rrefId, forkId, responseFuture, lsctx = std::move(lsctx)](
|
||||
JitFuture& future) {
|
||||
return valueFuture->then(
|
||||
[ownerRRef, rrefId, forkId, lsctx = std::move(lsctx)](JitFuture& future) {
|
||||
if (future.hasError()) {
|
||||
ownerRRef->setError(future.exception_ptr());
|
||||
} else {
|
||||
ownerRRef->recordAllStreams(lsctx);
|
||||
ownerRRef->setValue(future.value());
|
||||
}
|
||||
responseFuture->markCompleted(c10::make_intrusive<Message>(
|
||||
RemoteRet(rrefId, forkId).toMessage()));
|
||||
});
|
||||
return c10::make_intrusive<Message>(
|
||||
RemoteRet(rrefId, forkId).toMessage());
|
||||
},
|
||||
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptRemoteCall(
|
||||
ScriptRemoteCall& scriptRemoteCall,
|
||||
std::vector<at::IValue>& stack) const {
|
||||
RpcCommandBase& rpc) const {
|
||||
auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
|
||||
|
||||
TORCH_CHECK(
|
||||
scriptRemoteCall.hasOp(), "ScriptRemoteCall needs to have an op!");
|
||||
return processScriptRemoteCallOp(scriptRemoteCall, stack);
|
||||
}
|
||||
auto future =
|
||||
runJitOperator(*scriptRemoteCall.op(), scriptRemoteCall.stackRef());
|
||||
|
||||
void RequestCallbackNoPython::processBaseScriptRemoteCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture) const {
|
||||
auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
|
||||
auto& stack = scriptRemoteCall.stackRef();
|
||||
auto jitFuture = processScriptRemoteCall(scriptRemoteCall, stack);
|
||||
|
||||
assignOwnerRRef(
|
||||
return assignOwnerRRef(
|
||||
scriptRemoteCall.retRRefId(),
|
||||
scriptRemoteCall.retForkId(),
|
||||
std::move(jitFuture),
|
||||
responseFuture,
|
||||
std::move(future),
|
||||
/*lsctx=*/nullptr);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
|
||||
processScriptRemoteCallOp(
|
||||
ScriptRemoteCall& scriptRemoteCall,
|
||||
std::vector<at::IValue>& stack) const {
|
||||
TORCH_INTERNAL_ASSERT(scriptRemoteCall.hasOp());
|
||||
return runJitOperator(*scriptRemoteCall.op(), stack);
|
||||
}
|
||||
|
||||
void RequestCallbackNoPython::processScriptRRefFetchCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
@ -556,21 +529,16 @@ c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpc(
|
||||
// to a python object.
|
||||
switch (messageType) {
|
||||
case MessageType::SCRIPT_CALL: {
|
||||
processScriptCall(rpc, markComplete, responseFuture);
|
||||
return responseFuture;
|
||||
return processScriptCall(rpc);
|
||||
}
|
||||
case MessageType::PYTHON_CALL: {
|
||||
processPythonCall(rpc, markComplete, responseFuture);
|
||||
return responseFuture;
|
||||
return processPythonCall(rpc);
|
||||
}
|
||||
case MessageType::SCRIPT_REMOTE_CALL: {
|
||||
processBaseScriptRemoteCall(rpc, markComplete, responseFuture);
|
||||
return responseFuture;
|
||||
return processScriptRemoteCall(rpc);
|
||||
}
|
||||
case MessageType::PYTHON_REMOTE_CALL: {
|
||||
processPythonRemoteCall(
|
||||
rpc, markComplete, responseFuture, std::move(ctx));
|
||||
return responseFuture;
|
||||
return processPythonRemoteCall(rpc, std::move(ctx));
|
||||
}
|
||||
case MessageType::SCRIPT_RREF_FETCH_CALL: {
|
||||
processScriptRRefFetchCall(rpc, markComplete, responseFuture);
|
||||
|
@ -23,45 +23,23 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback {
|
||||
std::unique_ptr<RpcCommandBase> rpc,
|
||||
const MessageType& messageType) const;
|
||||
|
||||
virtual void processScriptCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture) const;
|
||||
virtual c10::intrusive_ptr<JitFuture> processScriptCall(
|
||||
RpcCommandBase& rpc) const;
|
||||
|
||||
void processScriptCallOp(
|
||||
ScriptCall& scriptCall,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
std::vector<at::IValue>& stack) const;
|
||||
virtual c10::intrusive_ptr<JitFuture> processPythonCall(
|
||||
RpcCommandBase& rpc) const;
|
||||
|
||||
virtual void processPythonCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture) const;
|
||||
|
||||
void assignOwnerRRef(
|
||||
c10::intrusive_ptr<JitFuture> assignOwnerRRef(
|
||||
const RRefId& rrefId,
|
||||
const RRefId& forkId,
|
||||
c10::intrusive_ptr<JitFuture> valueFuture,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture,
|
||||
std::shared_ptr<LazyStreamContext> lsctx) const;
|
||||
|
||||
virtual c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
|
||||
ScriptRemoteCall& scriptRemoteCall,
|
||||
std::vector<at::IValue>& stack) const;
|
||||
RpcCommandBase& rpc) const;
|
||||
|
||||
void processBaseScriptRemoteCall(
|
||||
virtual c10::intrusive_ptr<JitFuture> processPythonRemoteCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture) const;
|
||||
|
||||
c10::intrusive_ptr<JitFuture> processScriptRemoteCallOp(
|
||||
ScriptRemoteCall& scriptRemoteCall,
|
||||
std::vector<at::IValue>& stack) const;
|
||||
|
||||
virtual void processPythonRemoteCall(
|
||||
RpcCommandBase& rpc,
|
||||
const std::function<void(Message)>& markComplete,
|
||||
const c10::intrusive_ptr<JitFuture>& responseFuture,
|
||||
std::shared_ptr<LazyStreamContext> ctx) const;
|
||||
|
||||
void processScriptRRefFetchCall(
|
||||
|
Reference in New Issue
Block a user