Files
pytorch/torch/csrc/distributed/rpc/rref_proto.cpp
Scott Wolchok e88d1c4f10 [PyTorch] Add tuple inline storage (#64066)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64066

I noticed a bunch of time being spent heap-allocating Tuples
in the unpickler. 1-, 2-, and 3-element Tuples are apparently common
enough that they get their own bytecode instructions, so I decided to
try also giving them their own representation. We store up to 3
IValues inline in `Tuple` rather than doing a second heap allocation
for a `std::vector<IValue>`.
ghstack-source-id: 140695395

Test Plan:
Added automated tests for TupleElements.

Pixel 3 before: https://www.internalfb.com/intern/aibench/details/761596366576284
Pixel 3 after: https://www.internalfb.com/intern/aibench/details/591414145082422
We went from 347 ms to 302 ms.

Reviewed By: dhruvbird

Differential Revision: D30592622

fbshipit-source-id: 93625c54c9dca5f765ef6d5c191944179cb281a8
2021-10-15 12:16:51 -07:00

195 lines
6.3 KiB
C++

#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <limits>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
// Unpickles the payload of `message` and returns the elements of the
// top-level tuple it contains. Asserts that the message carries the
// expected `type` before touching the payload.
c10::ivalue::TupleElements toIValues(const Message& message, MessageType type) {
  TORCH_INTERNAL_ASSERT(
      message.type() == type,
      "Expecting message of type ",
      type,
      ", but got ",
      message.type());
  const auto* data = static_cast<const char*>(message.payload().data());
  const auto size = message.payload().size();
  auto tuple = jit::unpickle(
      data,
      size,
      *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
      message.tensors());
  // Move the elements out of the tuple instead of copying each IValue.
  return std::move(*std::move(tuple).toTuple()).elements();
}
// Pickles `ivalues` as a single tuple and wraps the resulting bytes,
// together with any tensors referenced by the pickle, into a Message of
// the given `type`.
c10::intrusive_ptr<Message> fromIValues(
    std::vector<IValue> ivalues,
    MessageType type) {
  std::vector<torch::Tensor> tensors;
  auto pickled =
      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensors);
  return c10::make_intrusive<Message>(
      std::move(pickled), std::move(tensors), type);
}
} // namespace
/////////////////////////// RRefMessageBase //////////////////////////////////
// Returns the id of the RRef this message refers to.
const RRefId& RRefMessageBase::rrefId() {
return rrefId_;
}
/////////////////////////// ForkMessageBase //////////////////////////////////
// Returns the id of the fork this message refers to.
const ForkId& ForkMessageBase::forkId() {
return forkId_;
}
// Serializes the (rrefId, forkId) pair into a Message of this object's type.
c10::intrusive_ptr<Message> ForkMessageBase::toMessageImpl() && {
  std::vector<IValue> fields{rrefId_.toIValue(), forkId_.toIValue()};
  return fromIValues(std::move(fields), type_);
}
std::pair<RRefId, ForkId> ForkMessageBase::fromMessage(
const Message& message,
MessageType type) {
auto ivalues = toIValues(message, type);
TORCH_INTERNAL_ASSERT(
ivalues.size() == 2, "ForkMessageBase expects 2 IValue from message.");
return std::make_pair(
RRefId::fromIValue(ivalues[0]), ForkId::fromIValue(ivalues[1]));
}
/////////////////////////// RRef Protocol //////////////////////////////////
// Serializes this fetch call as (rrefId, fromWorkerId).
c10::intrusive_ptr<Message> ScriptRRefFetchCall::toMessageImpl() && {
  std::vector<at::IValue> fields;
  fields.reserve(2);
  fields.emplace_back(rrefId_.toIValue());
  fields.emplace_back(fromWorkerId_);
  return fromIValues(std::move(fields), MessageType::SCRIPT_RREF_FETCH_CALL);
}
std::unique_ptr<ScriptRRefFetchCall> ScriptRRefFetchCall::fromMessage(
const Message& message) {
auto values = toIValues(message, MessageType::SCRIPT_RREF_FETCH_CALL);
TORCH_INTERNAL_ASSERT(
values.size() == 2, "ScriptRRefFetchCall expects 2 IValues from message");
auto id = values[1].toInt();
TORCH_INTERNAL_ASSERT(
id >= std::numeric_limits<worker_id_t>::min() &&
id <= std::numeric_limits<worker_id_t>::max(),
"ScriptRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
return std::make_unique<ScriptRRefFetchCall>(
worker_id_t(id), RRefId::fromIValue(values[0]));
}
// Serializes this fetch call as (rrefId, fromWorkerId).
c10::intrusive_ptr<Message> PythonRRefFetchCall::toMessageImpl() && {
  std::vector<at::IValue> fields;
  fields.reserve(2);
  fields.emplace_back(rrefId_.toIValue());
  fields.emplace_back(fromWorkerId_);
  return fromIValues(std::move(fields), MessageType::PYTHON_RREF_FETCH_CALL);
}
std::unique_ptr<PythonRRefFetchCall> PythonRRefFetchCall::fromMessage(
const Message& message) {
auto values = toIValues(message, MessageType::PYTHON_RREF_FETCH_CALL);
TORCH_INTERNAL_ASSERT(
values.size() == 2, "PythonRRefFetchCall expects 2 IValues from message");
auto id = values[1].toInt();
TORCH_INTERNAL_ASSERT(
id >= std::numeric_limits<worker_id_t>::min() &&
id <= std::numeric_limits<worker_id_t>::max(),
"PythonRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
return std::make_unique<PythonRRefFetchCall>(
worker_id_t(id), RRefId::fromIValue(values[0]));
}
// Returns the fetched values carried by this response.
const std::vector<at::IValue>& RRefFetchRet::values() {
return values_;
}
// Serializes the stored values into a Message of this object's type.
// This overload is rvalue-qualified, so `values_` can be moved into
// fromIValues' by-value parameter instead of deep-copying every IValue
// (matches the move-out pattern of the other toMessageImpl overloads).
c10::intrusive_ptr<Message> RRefFetchRet::toMessageImpl() && {
  return fromIValues(std::move(values_), type_);
}
std::unique_ptr<ScriptRRefFetchRet> ScriptRRefFetchRet::fromMessage(
const Message& message) {
auto values = toIValues(message, MessageType::SCRIPT_RREF_FETCH_RET);
TORCH_INTERNAL_ASSERT(
values.size() == 1,
"RRef of IValue should contain a single IValue, but got ",
values.size());
return std::make_unique<ScriptRRefFetchRet>(std::move(values).vec());
}
// Deserializes a PYTHON_RREF_FETCH_RET message; all tuple elements become
// the response's values.
std::unique_ptr<PythonRRefFetchRet> PythonRRefFetchRet::fromMessage(
    const Message& message) {
  auto elements = toIValues(message, MessageType::PYTHON_RREF_FETCH_RET).vec();
  return std::make_unique<PythonRRefFetchRet>(std::move(elements));
}
// Deserializes an RREF_USER_DELETE message into its (rrefId, forkId) pair.
std::unique_ptr<RRefUserDelete> RRefUserDelete::fromMessage(
    const Message& message) {
  const auto ids =
      ForkMessageBase::fromMessage(message, MessageType::RREF_USER_DELETE);
  // NOTE(review): the temporary looks redundant, but it may be required if
  // the constructor is only accessible from this scope — pattern preserved.
  return std::make_unique<RRefUserDelete>(
      RRefUserDelete(ids.first, ids.second));
}
// Deserializes a REMOTE_RET message into its (rrefId, forkId) pair.
std::unique_ptr<RemoteRet> RemoteRet::fromMessage(const Message& message) {
  const auto ids =
      ForkMessageBase::fromMessage(message, MessageType::REMOTE_RET);
  return std::make_unique<RemoteRet>(ids.first, ids.second);
}
// Returns the id of the accepted fork.
const ForkId& RRefChildAccept::forkId() const {
return forkId_;
}
// Serializes the accepted fork id as the sole element of the message.
c10::intrusive_ptr<Message> RRefChildAccept::toMessageImpl() && {
  std::vector<IValue> fields;
  fields.emplace_back(forkId_.toIValue());
  return fromIValues(std::move(fields), MessageType::RREF_CHILD_ACCEPT);
}
std::unique_ptr<RRefChildAccept> RRefChildAccept::fromMessage(
const Message& message) {
auto values = toIValues(message, MessageType::RREF_CHILD_ACCEPT);
TORCH_INTERNAL_ASSERT(values.size() == 1, "Expect 1 IValues from message.");
return std::make_unique<RRefChildAccept>(ForkId::fromIValue(values.back()));
}
// Deserializes an RREF_FORK_REQUEST message into its (rrefId, forkId) pair.
std::unique_ptr<RRefForkRequest> RRefForkRequest::fromMessage(
    const Message& message) {
  const auto ids =
      ForkMessageBase::fromMessage(message, MessageType::RREF_FORK_REQUEST);
  return std::make_unique<RRefForkRequest>(ids.first, ids.second);
}
// An ack carries no payload and no tensors — only the message type matters.
c10::intrusive_ptr<Message> RRefAck::toMessageImpl() && {
  std::vector<char> emptyPayload;
  std::vector<torch::Tensor> noTensors;
  return c10::make_intrusive<Message>(
      std::move(emptyPayload), std::move(noTensors), MessageType::RREF_ACK);
}
// Validates that `message` is an RREF_ACK; an ack carries no payload, so
// there is nothing further to parse.
// Fixed typo in the assertion text: "miss match" -> "mismatch".
std::unique_ptr<RRefAck> RRefAck::fromMessage(const Message& message) {
  TORCH_INTERNAL_ASSERT(
      message.type() == MessageType::RREF_ACK,
      "Message type mismatch, expect ",
      MessageType::RREF_ACK,
      ", but got ",
      message.type());
  return std::make_unique<RRefAck>();
}
} // namespace rpc
} // namespace distributed
} // namespace torch