Files
pytorch/torch/csrc/distributed/rpc/request_callback_impl.h
Luca Wehrstedt 20d02cb7dd Remove getScriptRemoteCallType (#57854)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57854

Because OwnerRRefs used to be created before their value was computed, we had to figure out their type ahead of time. The previous diff inverted the order of operations: we now compute the result first and then create the OwnerRRef. This means we can simply inspect the value to get its type, which is much simpler and much less error-prone.
ghstack-source-id: 129567060

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28253843

fbshipit-source-id: f13c9b294f477ae66fcbdbc85c642fdc69b2740f
2021-05-21 13:15:07 -07:00

66 lines
2.1 KiB
C++

#pragma once
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/request_callback_no_python.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
namespace torch {
namespace distributed {
namespace rpc {
// Request callback that overrides the RequestCallbackNoPython hooks declared
// below (all of them are const member functions). Judging by the hook names
// (deserializePythonRpcCommand, processPythonCall, processPythonRemoteCall,
// ...), this is presumably the variant used when Python support is available;
// NOTE(review): confirm against request_callback_no_python.h.
class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython {
 public:
  // Takes ownership of `rpc` (a command of type `messageType`) and returns a
  // command object in its place. NOTE(review): presumably unwraps/converts
  // Python-specific commands for the process* hooks below — confirm in the
  // .cpp.
  std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
      std::unique_ptr<RpcCommandBase> rpc,
      const MessageType& messageType) const override;

  // Handles a Python call carried by `rpc`. `markComplete` accepts a response
  // Message; `responseFuture` is the future associated with this request.
  // NOTE(review): which of the two delivers the result is decided by the
  // implementation — check the .cpp.
  void processPythonCall(
      RpcCommandBase& rpc,
      const std::function<void(Message)>& markComplete,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const override;

  // Same parameters as processPythonCall, for a TorchScript call.
  void processScriptCall(
      RpcCommandBase& rpc,
      const std::function<void(Message)>& markComplete,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const override;

  // Runs the remote TorchScript call described by `scriptRemoteCall` with the
  // arguments in `stack` (passed by non-const reference, so it may be
  // consumed/modified) and returns a future for the computed value. Unlike the
  // other process* hooks, it takes no markComplete callback. Per PR #57854,
  // the OwnerRRef is created only after this value is computed, so its type is
  // taken from the value itself (this replaced getScriptRemoteCallType).
  c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
      ScriptRemoteCall& scriptRemoteCall,
      std::vector<at::IValue>& stack) const override;

  // Handles a remote (presumably RRef-producing) Python call carried by
  // `rpc`. `ctx` is taken by value, so the callee holds its own shared
  // reference to the stream context. NOTE(review): presumably the CUDA
  // stream context for the request — confirm.
  void processPythonRemoteCall(
      RpcCommandBase& rpc,
      const std::function<void(Message)>& markComplete,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const override;

  // Handles a Python RRef fetch request carried by `rpc`, using
  // `responseFuture` for the reply. NOTE(review): presumably returns the value
  // of an RRef owned by this worker — confirm.
  void processPythonRRefFetchCall(
      RpcCommandBase& rpc,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const override;

  // Called when an RRef is deleted. `rref` is taken by non-const reference,
  // so the implementation may modify or reset the caller's pointer.
  void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const override;

  // Processes `rpc` of type `messageType` and returns a future for the
  // response. NOTE(review): "WithErrors" presumably means failures are
  // reported through the returned future rather than thrown — confirm.
  c10::intrusive_ptr<JitFuture> processRpcWithErrors(
      RpcCommandBase& rpc,
      const MessageType& messageType,
      std::shared_ptr<LazyStreamContext> ctx) const override;

  // Presumably reports whether CUDA is usable by this callback — confirm in
  // the .cpp.
  bool cudaAvailable() const override;

  // Handles an RRef backward request carried by `rpc`, using `responseFuture`
  // for the reply. NOTE(review): presumably distributed-autograd backward on
  // an RRef — confirm.
  void processRRefBackward(
      RpcCommandBase& rpc,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const override;

  // Helpers to run user-defined functions, operators and other computations.
  // runJitFunction is not an override. It runs the function identified by
  // `name` with the arguments in `stack` and returns a future for the result.
  // NOTE(review): `isAsyncExecution` presumably marks functions whose result
  // is itself a future (async execution) — confirm.
  c10::intrusive_ptr<JitFuture> runJitFunction(
      const c10::QualifiedName& name,
      std::vector<at::IValue>& stack,
      bool isAsyncExecution) const;
};
} // namespace rpc
} // namespace distributed
} // namespace torch