Files
pytorch/torch/csrc/distributed/rpc/request_callback_impl.h
Luca Wehrstedt cd9dbbd93a Simplify process(Script|Python)(Remote)?Call (#57857)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57857

There used to be many methods: `processPythonCall`, `processScriptCall`, `processScriptRemoteCall`, `processPythonRemoteCall`, `processScriptCallOp`, `processBaseScriptRemoteCall` and `processScriptRemoteCallOp`. Thanks to the previous simplification, we can now drop all but the first four, which map 1:1 to the four message types we need to handle. Their signatures also become much simpler: each one takes an RPC command and returns a future.
ghstack-source-id: 129567070

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28253848

fbshipit-source-id: e0e45345c414a96900f9d70ee555359d28908833
2021-05-21 13:15:12 -07:00

64 lines
2.0 KiB
C++

#pragma once
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/request_callback_no_python.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/jit/python/pybind.h>
namespace torch {
namespace distributed {
namespace rpc {
// Request callback that layers Python-aware handling on top of
// RequestCallbackNoPython (this header includes pybind.h and declares
// Python-specific members such as runPythonFunction).
// Every virtual member below is marked `override`, so its signature must be
// kept exactly in sync with the corresponding declaration in the base class.
class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython {
public:
// Takes ownership of an already-deserialized command together with its
// message type and returns the command that should actually be processed.
// NOTE(review): presumably this is where Python-specific payloads are
// unpacked; confirm against the .cpp implementation.
std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
std::unique_ptr<RpcCommandBase> rpc,
const MessageType& messageType) const override;
// The next four handlers map 1:1 to the four call message types this
// callback handles (judging by the names: Python call, script call, script
// remote call, Python remote call). Each takes the RPC command and returns
// a future for its result.
c10::intrusive_ptr<JitFuture> processPythonCall(
RpcCommandBase& rpc) const override;
c10::intrusive_ptr<JitFuture> processScriptCall(
RpcCommandBase& rpc) const override;
c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
RpcCommandBase& rpc) const override;
// Unlike the other three handlers, this one also receives a
// LazyStreamContext. TODO confirm how the context is used (presumably for
// device/stream handling of the produced value).
c10::intrusive_ptr<JitFuture> processPythonRemoteCall(
RpcCommandBase& rpc,
std::shared_ptr<LazyStreamContext> ctx) const override;
// Returns void: instead of creating a future, it is handed the
// `responseFuture` through which the result is presumably reported.
void processPythonRRefFetchCall(
RpcCommandBase& rpc,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const override;
// Hook called for an RRef being deleted. `rref` is taken by non-const
// reference, so the implementation may modify or reset the caller's pointer.
void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const override;
// Processing entry point that receives the message type and stream context
// explicitly and returns a future. NOTE(review): the name suggests it wraps
// processing with error handling; confirm the exact semantics in the .cpp.
c10::intrusive_ptr<JitFuture> processRpcWithErrors(
RpcCommandBase& rpc,
const MessageType& messageType,
std::shared_ptr<LazyStreamContext> ctx) const override;
// Overrides the base-class query for CUDA availability; the actual check
// lives in the implementation file.
bool cudaAvailable() const override;
// Handles an RRef backward request. Like processPythonRRefFetchCall, it
// returns void and is handed the `responseFuture` to complete.
void processRRefBackward(
RpcCommandBase& rpc,
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
// Helpers to run user-defined functions, operators and other computations.
// These are public, non-virtual members (not part of the base interface).
//
// Runs the JIT function identified by the qualified `name`, using `stack` as
// its argument stack. `stack` is passed by non-const reference and may be
// modified. TODO confirm: `isAsyncExecution` presumably means the function
// itself returns a future whose completion is awaited.
c10::intrusive_ptr<JitFuture> runJitFunction(
const c10::QualifiedName& name,
std::vector<at::IValue>& stack,
bool isAsyncExecution) const;
// Runs the given Python callable and returns a future for its result. No
// separate argument list is passed; presumably `function` already carries
// its arguments (verify against callers). `isAsyncExecution` as above.
c10::intrusive_ptr<JitFuture> runPythonFunction(
const py::object& function,
bool isAsyncExecution) const;
};
} // namespace rpc
} // namespace distributed
} // namespace torch