mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25527 Master GH issue: https://github.com/pytorch/pytorch/issues/23110. This change builds upon https://github.com/pytorch/pytorch/pull/24876 and provides all the autograd hooks needed for a forward pass with distributed rpc for builtin operators. This change does not address distributed rpc for python UDFs and that will be addressed in follow up PRs. Summary of changes: 1. Attach send autograd functions when a request is sent from the client and response is sent from the server. 2. Attach receive autograd functions when a request is received on the server and a response is received on the client. 3. Generate a globally unique autograd_message_id for each send/recv autograd function pair to uniquely identify them. ghstack-source-id: 91240466 Test Plan: unit tests. Differential Revision: D17148077 fbshipit-source-id: 192d8a3f552ed7cc939f55dcca332965c9bd3233
164 lines
5.4 KiB
C++
164 lines
5.4 KiB
C++
#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
|
|
#include <c10/util/C++17.h>
|
|
#include <torch/csrc/distributed/rpc/utils.h>
|
|
#include <torch/csrc/utils/byte_order.h>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
constexpr int kAutogradMessageSize = 17;
|
|
|
|
AutogradMetadata::AutogradMetadata(
|
|
int64_t autogradContextId_,
|
|
int64_t autogradMessageId_)
|
|
: autogradContextId(autogradContextId_),
|
|
autogradMessageId(autogradMessageId_) {}
|
|
|
|
RpcWithAutograd::RpcWithAutograd(
|
|
MessageType messageType,
|
|
const AutogradMetadata& autogradMetadata,
|
|
std::unique_ptr<RpcCommandBase> wrappedRpc)
|
|
: messageType_(messageType), autogradMetadata_(autogradMetadata) {
|
|
TORCH_INTERNAL_ASSERT(wrappedRpc != nullptr, "wrappedRpc cannot be null!");
|
|
TORCH_INTERNAL_ASSERT(
|
|
messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_REQ ||
|
|
messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_RESP);
|
|
wrappedMessage_ = std::move(*wrappedRpc).toMessage();
|
|
tensors_ = wrappedMessage_.tensors();
|
|
wrappedMessageType_ = wrappedMessage_.type();
|
|
}
|
|
|
|
RpcWithAutograd::RpcWithAutograd(
|
|
MessageType messageType,
|
|
const AutogradMetadata& autogradMetadata,
|
|
std::unique_ptr<RpcCommandBase> wrappedRpc,
|
|
MessageType wrappedMessageType,
|
|
std::vector<torch::Tensor> tensors)
|
|
: messageType_(messageType),
|
|
autogradMetadata_(autogradMetadata),
|
|
wrappedRpc_(std::move(wrappedRpc)),
|
|
wrappedMessageType_(wrappedMessageType),
|
|
tensors_(std::move(tensors)) {
|
|
TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
|
|
TORCH_INTERNAL_ASSERT(
|
|
messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_REQ ||
|
|
messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_RESP);
|
|
}
|
|
|
|
Message RpcWithAutograd::toMessage() && {
|
|
auto messageId = wrappedMessage_.id();
|
|
auto messageType = wrappedMessage_.type();
|
|
|
|
auto payload = std::move(wrappedMessage_).movePayload();
|
|
TORCH_INTERNAL_ASSERT(!payload.empty());
|
|
|
|
// We append the message type (1 byte), autograd context id(8 bytes) and
|
|
// autograd message id(8 bytes) to the original message in network byte order
|
|
// (big endian).
|
|
size_t writableIndex = payload.size();
|
|
|
|
// Need 17 additional bytes.
|
|
payload.resize(payload.size() + kAutogradMessageSize);
|
|
|
|
// Add message type.
|
|
payload[writableIndex++] = messageType;
|
|
|
|
// Add autograd ids.
|
|
torch::utils::THP_encodeInt64Buffer(
|
|
reinterpret_cast<uint8_t*>(payload.data()) + writableIndex,
|
|
&autogradMetadata_.autogradContextId,
|
|
torch::utils::THPByteOrder::THP_BIG_ENDIAN,
|
|
1);
|
|
writableIndex += sizeof(int64_t);
|
|
torch::utils::THP_encodeInt64Buffer(
|
|
reinterpret_cast<uint8_t*>(payload.data()) + writableIndex,
|
|
&autogradMetadata_.autogradMessageId,
|
|
torch::utils::THPByteOrder::THP_BIG_ENDIAN,
|
|
1);
|
|
|
|
return Message(
|
|
std::move(payload), std::move(tensors_), messageType_, messageId);
|
|
}
|
|
|
|
std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
|
|
const Message& message) {
|
|
MessageType originalMessageType = message.type();
|
|
TORCH_INTERNAL_ASSERT(
|
|
MessageType::MESSAGE_WITH_AUTOGRAD_REQ == originalMessageType ||
|
|
MessageType::MESSAGE_WITH_AUTOGRAD_RESP == originalMessageType);
|
|
|
|
std::vector<torch::Tensor> tensors = message.tensors();
|
|
int64_t messageId = message.id();
|
|
// Decode message type, autograd context id and autograd message id.
|
|
auto payload = message.payload();
|
|
TORCH_INTERNAL_ASSERT(payload.size() > kAutogradMessageSize);
|
|
|
|
int64_t autogradContextId, autogradMessageId;
|
|
// autograd message id.
|
|
size_t indexToRead = payload.size() - sizeof(int64_t);
|
|
TORCH_INTERNAL_ASSERT(indexToRead >= 0);
|
|
torch::utils::THP_decodeInt64Buffer(
|
|
&autogradMessageId,
|
|
reinterpret_cast<uint8_t*>(payload.data()) + indexToRead,
|
|
torch::utils::THPByteOrder::THP_BIG_ENDIAN,
|
|
1);
|
|
|
|
// autograd context id.
|
|
indexToRead -= sizeof(int64_t);
|
|
TORCH_INTERNAL_ASSERT(indexToRead >= 0);
|
|
torch::utils::THP_decodeInt64Buffer(
|
|
&autogradContextId,
|
|
reinterpret_cast<uint8_t*>(payload.data()) + indexToRead,
|
|
torch::utils::THPByteOrder::THP_BIG_ENDIAN,
|
|
1);
|
|
|
|
// message type.
|
|
indexToRead -= 1;
|
|
TORCH_INTERNAL_ASSERT(indexToRead >= 0);
|
|
MessageType wrappedMessageType =
|
|
static_cast<MessageType>(payload[indexToRead]);
|
|
|
|
// Remove the autograd information.
|
|
payload.resize(payload.size() - kAutogradMessageSize);
|
|
|
|
// Create new message type and build wrapped RPC.
|
|
Message wrappedMessage(
|
|
std::move(payload), std::move(tensors), wrappedMessageType, messageId);
|
|
|
|
std::unique_ptr<RpcCommandBase> wrappedRpc;
|
|
if (originalMessageType == MessageType::MESSAGE_WITH_AUTOGRAD_REQ) {
|
|
wrappedRpc = deserializeRequest(wrappedMessage);
|
|
} else {
|
|
wrappedRpc = deserializeResponse(wrappedMessage);
|
|
}
|
|
|
|
return c10::guts::make_unique<RpcWithAutograd>(
|
|
originalMessageType,
|
|
AutogradMetadata(autogradContextId, autogradMessageId),
|
|
std::move(wrappedRpc),
|
|
wrappedMessageType,
|
|
std::move(tensors));
|
|
}
|
|
|
|
std::vector<torch::Tensor>& RpcWithAutograd::tensors() {
|
|
return tensors_;
|
|
}
|
|
|
|
const AutogradMetadata& RpcWithAutograd::autogradMetadata() const {
|
|
return autogradMetadata_;
|
|
}
|
|
|
|
RpcCommandBase& RpcWithAutograd::wrappedRpc() {
|
|
TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
|
|
return *wrappedRpc_;
|
|
}
|
|
|
|
MessageType RpcWithAutograd::wrappedMessageType() const {
|
|
return wrappedMessageType_;
|
|
}
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|