mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57857 There used to be a whole lot of methods: `processPythonCall`, `processScriptCall`, `processScriptRemoteCall`, `processPythonRemoteCall`, `processScriptCallOp`, `processBaseScriptRemoteCall` and `processScriptRemoteCallOp`. Thanks to the previous simplification, we can now drop all but the first four, which map nicely 1:1 to the four message types we need to handle. Also their signatures become much simpler: they take an RPC command and return a future. ghstack-source-id: 129567070 Test Plan: CI Reviewed By: mrshenli Differential Revision: D28253848 fbshipit-source-id: e0e45345c414a96900f9d70ee555359d28908833
64 lines
2.0 KiB
C++
64 lines
2.0 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/distributed/rpc/message.h>
|
|
#include <torch/csrc/distributed/rpc/request_callback_no_python.h>
|
|
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
|
#include <torch/csrc/jit/python/pybind.h>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython {
|
|
public:
|
|
std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
|
|
std::unique_ptr<RpcCommandBase> rpc,
|
|
const MessageType& messageType) const override;
|
|
|
|
c10::intrusive_ptr<JitFuture> processPythonCall(
|
|
RpcCommandBase& rpc) const override;
|
|
|
|
c10::intrusive_ptr<JitFuture> processScriptCall(
|
|
RpcCommandBase& rpc) const override;
|
|
|
|
c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
|
|
RpcCommandBase& rpc) const override;
|
|
|
|
c10::intrusive_ptr<JitFuture> processPythonRemoteCall(
|
|
RpcCommandBase& rpc,
|
|
std::shared_ptr<LazyStreamContext> ctx) const override;
|
|
|
|
void processPythonRRefFetchCall(
|
|
RpcCommandBase& rpc,
|
|
const c10::intrusive_ptr<JitFuture>& responseFuture,
|
|
std::shared_ptr<LazyStreamContext> ctx) const override;
|
|
|
|
void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const override;
|
|
|
|
c10::intrusive_ptr<JitFuture> processRpcWithErrors(
|
|
RpcCommandBase& rpc,
|
|
const MessageType& messageType,
|
|
std::shared_ptr<LazyStreamContext> ctx) const override;
|
|
|
|
bool cudaAvailable() const override;
|
|
|
|
void processRRefBackward(
|
|
RpcCommandBase& rpc,
|
|
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
|
|
|
|
// Helpers to run user-defined functions, operators and other computations.
|
|
|
|
c10::intrusive_ptr<JitFuture> runJitFunction(
|
|
const c10::QualifiedName& name,
|
|
std::vector<at::IValue>& stack,
|
|
bool isAsyncExecution) const;
|
|
|
|
c10::intrusive_ptr<JitFuture> runPythonFunction(
|
|
const py::object& function,
|
|
bool isAsyncExecution) const;
|
|
};
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|