mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57854 Because OwnerRRefs used to be created before their value was computed, we had to figure out their type ahead of time. After the previous diff, we inverted the order of operations, and we can now first compute the result and then create the OwnerRRef, which means we can just inspect the value to get its type. Much simpler, and much less likely to get it wrong. ghstack-source-id: 129567060 Test Plan: CI Reviewed By: mrshenli Differential Revision: D28253843 fbshipit-source-id: f13c9b294f477ae66fcbdbc85c642fdc69b2740f
66 lines
2.1 KiB
C++
66 lines
2.1 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/distributed/rpc/message.h>
|
|
#include <torch/csrc/distributed/rpc/request_callback_no_python.h>
|
|
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython {
|
|
public:
|
|
std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
|
|
std::unique_ptr<RpcCommandBase> rpc,
|
|
const MessageType& messageType) const override;
|
|
|
|
void processPythonCall(
|
|
RpcCommandBase& rpc,
|
|
const std::function<void(Message)>& markComplete,
|
|
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
|
|
|
|
void processScriptCall(
|
|
RpcCommandBase& rpc,
|
|
const std::function<void(Message)>& markComplete,
|
|
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
|
|
|
|
c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
|
|
ScriptRemoteCall& scriptRemoteCall,
|
|
std::vector<at::IValue>& stack) const override;
|
|
|
|
void processPythonRemoteCall(
|
|
RpcCommandBase& rpc,
|
|
const std::function<void(Message)>& markComplete,
|
|
const c10::intrusive_ptr<JitFuture>& responseFuture,
|
|
std::shared_ptr<LazyStreamContext> ctx) const override;
|
|
|
|
void processPythonRRefFetchCall(
|
|
RpcCommandBase& rpc,
|
|
const c10::intrusive_ptr<JitFuture>& responseFuture,
|
|
std::shared_ptr<LazyStreamContext> ctx) const override;
|
|
|
|
void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const override;
|
|
|
|
c10::intrusive_ptr<JitFuture> processRpcWithErrors(
|
|
RpcCommandBase& rpc,
|
|
const MessageType& messageType,
|
|
std::shared_ptr<LazyStreamContext> ctx) const override;
|
|
|
|
bool cudaAvailable() const override;
|
|
|
|
void processRRefBackward(
|
|
RpcCommandBase& rpc,
|
|
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
|
|
|
|
// Helpers to run user-defined functions, operators and other computations.
|
|
|
|
c10::intrusive_ptr<JitFuture> runJitFunction(
|
|
const c10::QualifiedName& name,
|
|
std::vector<at::IValue>& stack,
|
|
bool isAsyncExecution) const;
|
|
};
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|