#include #include #include #include #include #include #include #include #include namespace torch { namespace distributed { namespace rpc { std::unique_ptr deserializeRequest(const Message& request) { switch (request.type()) { case MessageType::SCRIPT_CALL: { return ScriptCall::fromMessage(request); } case MessageType::PYTHON_CALL: { return PythonUDFCall::fromMessage(request); } case MessageType::SCRIPT_REMOTE_CALL: { return ScriptRemoteCall::fromMessage(request); } case MessageType::PYTHON_REMOTE_CALL: { return PythonRemoteCall::fromMessage(request); } case MessageType::SCRIPT_RREF_FETCH_CALL: { return ScriptRRefFetchCall::fromMessage(request); } case MessageType::PYTHON_RREF_FETCH_CALL: { return PythonRRefFetchCall::fromMessage(request); } case MessageType::RREF_USER_DELETE: { return RRefUserDelete::fromMessage(request); } case MessageType::RREF_CHILD_ACCEPT: { return RRefChildAccept::fromMessage(request); } case MessageType::RREF_FORK_REQUEST: { return RRefForkRequest::fromMessage(request); } case MessageType::MESSAGE_WITH_AUTOGRAD_REQ: { return RpcWithAutograd::fromMessage(request); } default: { TORCH_INTERNAL_ASSERT( false, "Request type ", request.type(), " not supported."); } } } std::unique_ptr deserializeResponse(const Message& response) { switch (response.type()) { case MessageType::SCRIPT_RET: { return ScriptResp::fromMessage(response); } case MessageType::PYTHON_RET: { return PythonUDFResp::fromMessage(response); } case MessageType::REMOTE_RET: { return RemoteRet::fromMessage(response); } case MessageType::RREF_FETCH_RET: { return RRefFetchRet::fromMessage(response); } case MessageType::RREF_ACK: { return RRefAck::fromMessage(response); } case MessageType::EXCEPTION: { std::string err(response.payload().begin(), response.payload().end()); throw std::runtime_error(err); } case MessageType::MESSAGE_WITH_AUTOGRAD_RESP: { return RpcWithAutograd::fromMessage(response); } default: { TORCH_INTERNAL_ASSERT( false, "Response type ", response.type(), " not supported."); } } } } // namespace rpc } // namespace distributed } // namespace torch