From f4e7e9039d27703a90bee8914d77fe31ccc27b1e Mon Sep 17 00:00:00 2001
From: Jeremy Lilley
Date: Thu, 28 Nov 2019 09:55:33 -0800
Subject: [PATCH] Improve process_group_agent() serialization speed (#29785)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29785

TLDR: This change improves process_group's serialization speed:

  Serialize_Tensor64:     12.38us -> 1.99us    (~-84%)
  Deserialize_Tensor64:   33.89us -> 5.62us    (~-84%)
  Serialize_Tensor1M:     525.74us -> 285.43us (~-45%)
  Deserialize_Tensor1M:   892.61us -> 273.68us (~-70%)

After speaking with the jit team, we had consensus that torch::save()/load()
are somewhat high-overhead for RPC serialization, being mostly intended for
persistent disk data. (In particular, for large tensors, 35% of the time is
spent in CRC checking, even with the fb-side changes that substitute 40x-faster
SSE-accelerated CRC checking. For small tensors, the zip container overhead is
considerable, as is the overhead of lexing/parsing an embedded Python text
program for each RPC.)

The jit team encouraged us to use jit::pickler, with the WriteableTensorData
way of outputting result tensors (not the default side-tensor table, and not
pickling the actual tensors). This ends up pickling only some tensor metadata,
and gives us tensor blobs that we can mindlessly blit over the wire (they are
copied to CPU memory if needed).

There is as yet no standardized container format for the pickled data
(jit::pickle_save() is checked in, but it's experimental, and no load function
is provided yet), but they encouraged us to just use something sensible for
this and possibly revisit later. For now, I made the directory headers
slightly http-inspired.

Note that serialization is just one component of the pipeline, but, that said,
we also see reasonable reductions in end-to-end echo times (noisier):

  ProcessGroupAgent_Echo(Tensor_Small):    855.25us -> 492.65us (~-42%)
  ProcessGroupAgent_Echo(Tensor_1M):       10.82ms  -> 6.94ms   (~-35%)
  ProcessGroupAgent_Echo(Small_NoTensor):  688.82us -> 301.72us (~-56%)
  ProcessGroupAgent_Echo(1MB_NoTensor):    4.65ms   -> 3.71ms   (~-20%)

I moved the "wire serialization" logic to a separate file to assist with
unit testing.

ghstack-source-id: 94694682

Test Plan:
  buck test mode/dev-nosan caffe2/test/cpp/api:serialize
  buck test mode/dev-nosan caffe2/test/...
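Wire format illustration: wireSerialize() below emits newline-terminated
"name size" header entries, then a blank line, then the raw section bodies
concatenated in header order. For a message with a 5-byte payload and one
tensor, the output looks roughly like this (the sizes here are made up):

  payload 5
  meta 190
  0 100

  [5 payload bytes][190 bytes of pickle metadata][100 bytes of tensor "0" data]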
Differential Revision: D18493938

fbshipit-source-id: 07ddfe87dbe56472bc944f7d070627052c94a8f4
---
 caffe2/CMakeLists.txt                        |   1 +
 test/cpp/rpc/CMakeLists.txt                  |  27 +++
 test/cpp/rpc/test_wire_serialization.cpp     |  39 ++++
 .../distributed/rpc/process_group_agent.cpp  |  85 ++-------
 .../distributed/rpc/process_group_agent.h    |   9 +-
 torch/csrc/distributed/rpc/utils.cpp         | 180 ++++++++++++++++++
 torch/csrc/distributed/rpc/utils.h           |  10 +
 7 files changed, 283 insertions(+), 68 deletions(-)
 create mode 100644 test/cpp/rpc/CMakeLists.txt
 create mode 100644 test/cpp/rpc/test_wire_serialization.cpp

diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 007ad3de349c..56c16495b39e 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -745,6 +745,7 @@ ENDIF()
 
 if (BUILD_TEST AND NOT MSVC AND NOT USE_ROCM)
   add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
+  add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc)
 endif()
 
 if (BUILD_TEST AND NOT NO_API)
diff --git a/test/cpp/rpc/CMakeLists.txt b/test/cpp/rpc/CMakeLists.txt
new file mode 100644
index 000000000000..929db9094226
--- /dev/null
+++ b/test/cpp/rpc/CMakeLists.txt
@@ -0,0 +1,27 @@
+set(TORCH_RPC_TEST_DIR "${TORCH_ROOT}/test/cpp/rpc")
+set(TORCH_RPC_TEST_SOURCES
+  ${TORCH_ROOT}/test/cpp/common/main.cpp
+  ${TORCH_RPC_TEST_DIR}/test_wire_serialization.cpp
+)
+
+add_executable(test_cpp_rpc ${TORCH_RPC_TEST_SOURCES})
+target_include_directories(test_cpp_rpc PRIVATE ${ATen_CPU_INCLUDE})
+target_link_libraries(test_cpp_rpc PRIVATE torch gtest)
+
+if (USE_CUDA)
+  target_link_libraries(test_cpp_rpc PRIVATE
+    ${CUDA_LIBRARIES}
+    ${CUDA_NVRTC_LIB}
+    ${CUDA_CUDA_LIB}
+    ${TORCH_CUDA_LIBRARIES})
+
+  target_compile_definitions(test_cpp_rpc PRIVATE "USE_CUDA")
+endif()
+
+if (INSTALL_TEST)
+  install(TARGETS test_cpp_rpc DESTINATION bin)
+  # Install PDB files for MSVC builds
+  if (MSVC AND BUILD_SHARED_LIBS)
+    install(FILES $<TARGET_PDB_FILE:test_cpp_rpc> DESTINATION bin OPTIONAL)
+  endif()
+endif()
diff --git a/test/cpp/rpc/test_wire_serialization.cpp b/test/cpp/rpc/test_wire_serialization.cpp
new file mode 100644
index 000000000000..e1631eb20cbc
--- /dev/null
+++ b/test/cpp/rpc/test_wire_serialization.cpp
@@ -0,0 +1,39 @@
+#include <gtest/gtest.h>
+
+#include <torch/csrc/distributed/rpc/utils.h>
+#include <torch/torch.h>
+
+#include <cstring>
+#include <string>
+#include <vector>
+
+using namespace torch::distributed::rpc;
+
+TEST(WireSerialize, Base) {
+  auto run = [](const std::string& payload,
+                const std::vector<torch::Tensor>& tensors) {
+    std::string serialized;
+    {
+      std::vector<char> mpayload(payload.begin(), payload.end());
+      std::vector<torch::Tensor> mtensors = tensors;
+      serialized = torch::distributed::rpc::wireSerialize(
+          std::move(mpayload), std::move(mtensors));
+    }
+    auto deser = torch::distributed::rpc::wireDeserialize(
+        serialized.data(), serialized.size());
+    EXPECT_EQ(payload.size(), deser.first.size());
+    EXPECT_EQ(tensors.size(), deser.second.size());
+    if (payload.size() > 0) {
+      EXPECT_TRUE(
+          memcmp(deser.first.data(), payload.data(), payload.size()) == 0);
+    }
+    for (size_t i = 0; i < tensors.size(); ++i) {
+      EXPECT_TRUE(torch::equal(tensors[i], deser.second[i]));
+    }
+  };
+  run("", {});
+  run("hi", {});
+  run("", {torch::randn({5, 5})});
+  run("hi", {torch::randn({5, 5})});
+  run("more", {torch::randn({5, 5}), torch::rand({10, 10})});
+}
diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp
index a38992fb4150..b96957e02cc3 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.cpp
+++ b/torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -3,6 +3,7 @@
 #include <c10/util/C++17.h>
 #include <c10d/ProcessGroup.hpp>
 #include <torch/csrc/distributed/rpc/request_callback.h>
+#include <torch/csrc/distributed/rpc/utils.h>
 
 #include <algorithm>
 
@@ -10,63 +11,6 @@ namespace torch {
 namespace distributed {
 namespace rpc {
 
-namespace {
-
-// Write the message into the given ostream
-std::string serialize(const Message& message) {
-  // We cast const void* to void* here because we need to create a tensor using
-  // that memory space. If is fine as that tensor stays function-local, and will
-  // not be modified during its lifetime.
-  auto payload = const_cast<void*>( // NOLINT
-      static_cast<const void*>(message.payload().data()));
-  auto payload_size = message.payload().size();
-
-  // getting tensor table from the message
-  std::vector<torch::Tensor> tensors = message.tensors();
-  // append payload as a tensor
-  tensors.push_back(torch::from_blob(payload, payload_size, {torch::kChar}));
-  // append id and autograd metadata as a tensor
-  tensors.push_back(torch::tensor({message.id()}, {torch::kInt64}));
-
-  // optional: estimate output size, to avoid some unnecessary resizing.
-  static constexpr size_t kBaseOverhead = 2048;
-  static constexpr size_t kPerTensor = 128;
-  size_t estimate = kBaseOverhead;
-  for (const auto& t : tensors) {
-    estimate += t.nbytes() + kPerTensor;
-  }
-
-  std::string out;
-  out.reserve(estimate);
-  torch::save(tensors, [&](const void* buf, size_t n) -> size_t {
-    out.append(static_cast<const char*>(buf), n);
-    return n;
-  });
-  return out;
-}
-
-Message deserialize(MessageType type, const void* buf, size_t size) {
-  std::vector<torch::Tensor> tensors;
-
-  torch::load(tensors, static_cast<const char*>(buf), size);
-
-  TORCH_CHECK(tensors.size() >= 2, "Failed to deserialize a message.");
-  auto idTensor = std::move(tensors.back());
-  tensors.pop_back();
-  auto payloadTensor = std::move(tensors.back());
-  tensors.pop_back();
-
-  TORCH_INTERNAL_ASSERT(1, idTensor.numel());
-  int64_t id = idTensor.storage().data<int64_t>()[0];
-
-  const char* data = static_cast<const char*>(payloadTensor.storage().data());
-  std::vector<char> payload(data, data + payloadTensor.numel());
-
-  return Message(std::move(payload), std::move(tensors), type, id);
-}
-
-} // namespace
-
 ////////////////////////// MessageCounter /////////////////////////////////
 
 ProcessGroupAgent::MessageCounter::MessageCounter(int worldSize)
@@ -334,14 +278,15 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
       // Unlike the other cases, need to add a tensor deleter, since the
       // data outlives the scope of this function. It's shared_ptr<> due
       // to c++11 lambda capture limitations with unique_ptr<>.
-      auto payload =
-          c10::guts::make_unique<std::string>(serialize(message));
+      auto payload = c10::guts::make_unique<std::string>(
+          wireSerialize(message.payload(), message.tensors()));
       const char* data = payload->data();
       size_t len = payload->length();
       std::string* delete_when_done = payload.release();
       enqueueRecv(RecvWork(
           getWorkerInfo(pg_->getRank()),
           message.type(),
+          message.id(),
           torch::from_blob(
               (void*)data,
               len,
@@ -369,13 +314,15 @@ void ProcessGroupAgent::enqueueSend(SendWork work) {
   // NB: this can be changed to use a native move capture when moved to C++14
   threadPool_.run(std::bind(
       [this](const SendWork& work) {
-        std::string serializedPayload = serialize(work.message_);
+        std::string serializedPayload =
+            wireSerialize(work.message_.payload(), work.message_.tensors());
         std::vector<torch::Tensor> preamble = {torch::tensor(
            {(int64_t)pg_->getRank(),
             (int64_t)serializedPayload.length(),
-            (int64_t)work.message_.type()},
-           {torch::kLong})};
+            (int64_t)work.message_.type(),
+            (int64_t)work.message_.id()},
+           {torch::kInt64})};
 
         // ProcessGroup is not thread-safe when sending with the same tag, hence
         // the lock
@@ -416,8 +363,12 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) {
   threadPool_.run(std::bind(
       [&](RecvWork& work) {
         torch::Tensor& payload = work.payload_;
-        Message message =
-            deserialize(work.type_, payload.storage().data(), payload.numel());
+        auto data = wireDeserialize(payload.storage().data(), payload.numel());
+        Message message(
+            std::move(data.first),
+            std::move(data.second),
+            work.type_,
+            work.id_);
         if (message.isRequest()) {
           send(work.from_, cb_->operator()(message));
         } else if (message.isResponse()) {
@@ -467,7 +418,7 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) {
 void ProcessGroupAgent::listenLoop() {
   while (rpcRunning_.load()) {
     // rank, tensor size, message type
-    std::vector<torch::Tensor> preamble = {torch::empty({3}, {torch::kInt64})};
+    std::vector<torch::Tensor> preamble = {torch::empty({4}, {torch::kInt64})};
     auto work = pg_->recvAnysource(preamble, pg_->getRank());
     {
       std::lock_guard<std::mutex> guard(recvWorkMutex_);
@@ -483,6 +434,7 @@ void ProcessGroupAgent::listenLoop() {
     auto srcRank = preamble_items[0];
     auto size = preamble_items[1];
     MessageType type = MessageType(preamble_items[2]);
+    int64_t id = preamble_items[3];
 
     if (type == MessageType::SHUTDOWN) {
       // FIXME: This LOG also prints warnings no InitGoogleLogging() was invoked
@@ -496,7 +448,8 @@ void ProcessGroupAgent::listenLoop() {
     std::vector<torch::Tensor> tensors = {torch::empty({size}, {torch::kChar})};
     pg_->recv(tensors, srcRank, pg_->getRank())->wait();
 
-    enqueueRecv(RecvWork(allWorkerInfo_[srcRank], type, std::move(tensors[0])));
+    enqueueRecv(
+        RecvWork(allWorkerInfo_[srcRank], type, id, std::move(tensors[0])));
   }
 }
 
diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h
index 12df60157d04..7b271f4f671a 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.h
+++ b/torch/csrc/distributed/rpc/process_group_agent.h
@@ -31,11 +31,16 @@ struct SendWork {
 
 // SendWork wraps a Message and RecvWork wraps a Tensor. The difference here is
 // to allow us to run serialization/deserialization in the worker threads.
 struct RecvWork {
-  RecvWork(const WorkerInfo& from, MessageType type, torch::Tensor&& payload)
-      : from_(from), type_(type), payload_(payload) {}
+  RecvWork(
+      const WorkerInfo& from,
+      MessageType type,
+      int64_t id,
+      torch::Tensor&& payload)
+      : from_(from), type_(type), id_(id), payload_(payload) {}
 
   const WorkerInfo& from_;
   const MessageType type_;
+  const int64_t id_;
   torch::Tensor payload_;
 };
 
diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp
index 8ea4ce6f32f9..0bcbc5f8f917 100644
--- a/torch/csrc/distributed/rpc/utils.cpp
+++ b/torch/csrc/distributed/rpc/utils.cpp
@@ -11,6 +11,8 @@
 #include <torch/csrc/distributed/rpc/script_call.h>
 #include <torch/csrc/distributed/rpc/script_remote_call.h>
 #include <torch/csrc/distributed/rpc/script_resp.h>
+#include <torch/csrc/jit/pickle.h>
+#include <torch/csrc/jit/pickler.h>
 
 namespace torch {
 namespace distributed {
@@ -101,6 +103,184 @@ std::unique_ptr<RpcCommandBase> deserializeResponse(const Message& response) {
   }
 }
 
+namespace {
+
+// Helper for wireDeserialize() below.
+//
+// The format we use below looks like:
+//    section_name_1 size_1\n
+//    section_name_2 size_2\n
+//    ..
+//    \n
+//    [sections in order]
+//
+// Sections themselves include:
+//    - "payload" - the payload bits
+//    - "meta"    - metadata for the unpickler
+//    - "0" ...   - tensor sections for the unpickler
+//
+// Note that per the header comments, the format is subject to change,
+// and is best used for rpcs, rather than persistent disk storage.
+std::unordered_map<std::string, std::pair<const char*, size_t>>
+parseWireSections(const void* data, size_t data_size) {
+  const char* ptr = static_cast<const char*>(data);
+  const char* endp = ptr + data_size;
+
+  std::vector<std::pair<std::string, size_t>> headerEnts;
+  bool ok = false;
+  while (ptr != endp) {
+    if (*ptr == '\n') {
+      ok = true; // The only "correct" exit point.
+      ++ptr;
+      break;
+    }
+    // Parse name
+    const char* namePtr = ptr;
+    while (*ptr != ' ' && ptr != endp) {
+      ptr++;
+    }
+    if (ptr == endp) {
+      break;
+    }
+    std::string name(namePtr, ptr - namePtr);
+    if (++ptr == endp) {
+      break; // past the ' '
+    }
+    // Parse size
+    const char* sizePtr = ptr;
+    while (*ptr != '\n' && ptr != endp) {
+      ptr++;
+    }
+    if (ptr == endp) {
+      break;
+    }
+    size_t sz = c10::stoll(std::string(sizePtr, ptr - sizePtr));
+    headerEnts.emplace_back(std::make_pair(name, sz));
+    ++ptr; // past the '\n'
+  }
+  if (!ok) {
+    throw std::runtime_error("failed parse");
+  }
+
+  std::unordered_map<std::string, std::pair<const char*, size_t>> out;
+  for (const auto& headerEnt : headerEnts) {
+    out[headerEnt.first] = {ptr, headerEnt.second};
+    ptr += headerEnt.second;
+  }
+  if (ptr != endp) {
+    throw std::runtime_error("failed bounds");
+  }
+  return out;
+}
+
+static const char* kMeta = "meta";
+static const char* kPayload = "payload";
+}; // namespace
+
+std::string wireSerialize(
+    const std::vector<char>& payload,
+    const std::vector<torch::Tensor>& tensors) {
+  struct Ent {
+    std::string name;
+    const char* data;
+    size_t size;
+  };
+  std::vector<Ent> entries;
+  std::string metaEntry;
+  if (!payload.empty()) {
+    entries.push_back({kPayload, payload.data(), payload.size()});
+  }
+
+  if (!tensors.empty()) {
+    torch::jit::Pickler pickler(
+        [&](const void* buf, size_t sz) -> size_t {
+          metaEntry.append(static_cast<const char*>(buf), sz);
+          return sz;
+        },
+        nullptr);
+    pickler.protocol();
+    pickler.pushIValue(tensors);
+    pickler.stop();
+    auto writeable_tensors = pickler.tensorData();
+    entries.push_back({kMeta, metaEntry.data(), metaEntry.size()});
+    for (size_t i = 0; i < writeable_tensors.size(); i++) {
+      entries.push_back({c10::to_string(i),
+                         writeable_tensors[i].data(),
+                         writeable_tensors[i].sizeInBytes()});
+    }
+  }
+
+  std::string header;
+  size_t tot = 0;
+  for (const auto& e : entries) {
+    tot += e.size;
+    header.append(e.name)
+        .append(" ")
+        .append(c10::to_string(e.size))
+        .append("\n");
+  }
+  header.push_back('\n');
+
+  std::string out;
+  out.reserve(header.size() + tot);
+  out.append(header);
+  for (const auto& e : entries) {
+    out.append(e.data, e.size);
+  }
+  return out;
+}
+
+std::pair<std::vector<char>, std::vector<torch::Tensor>> wireDeserialize(
+    const void* data,
+    size_t data_size) {
+  auto sections = parseWireSections(data, data_size);
+
+  std::vector<char> payload;
+  auto payloadIt = sections.find(kPayload);
+  if (payloadIt != sections.end() && payloadIt->second.second != 0) {
+    payload.assign(
+        payloadIt->second.first,
+        payloadIt->second.first + payloadIt->second.second);
+  }
+
+  std::vector<torch::Tensor> tensors;
+  auto metaIt = sections.find(kMeta);
+  if (metaIt != sections.end()) {
+    const auto& metaData = metaIt->second;
+    size_t metaDataPos = 0;
+    auto metaDataReadFunc = [&](char* buf, size_t n) -> size_t {
+      if (metaDataPos >= metaData.second || n == 0) {
+        return 0;
+      }
+      size_t toCopy = std::min(metaDataPos + n, metaData.second) - metaDataPos;
+      memcpy(buf, metaData.first + metaDataPos, toCopy);
+      metaDataPos += toCopy;
+      return toCopy;
+    };
+    auto sectionReadFunc = [&](const std::string& ename) -> at::DataPtr {
+      auto it = sections.find(ename);
+      if (it == sections.end()) {
+        throw std::runtime_error("Couldn't find entity " + ename);
+      }
+      const auto& idat = it->second;
+      auto dptr = at::getCPUAllocator()->allocate(idat.second);
+      if (idat.second != 0) {
+        memcpy(dptr.get(), idat.first, idat.second);
+      }
+      return dptr;
+    };
+
+    torch::jit::Unpickler unpickler(
+        metaDataReadFunc, nullptr, nullptr, sectionReadFunc, {});
+    auto ival = unpickler.parse_ivalue();
+    for (auto&& t : ival.toTensorList()) {
+      tensors.emplace_back(std::move(t));
+    }
+  }
+  return {std::move(payload), std::move(tensors)};
+}
+
 } // namespace rpc
 } // namespace distributed
 } // namespace torch
diff --git a/torch/csrc/distributed/rpc/utils.h b/torch/csrc/distributed/rpc/utils.h
index d028bb3a7fa7..f34a4c53ed49 100644
--- a/torch/csrc/distributed/rpc/utils.h
+++ b/torch/csrc/distributed/rpc/utils.h
@@ -16,6 +16,16 @@ TORCH_API std::unique_ptr<RpcCommandBase> deserializeRequest(
 TORCH_API std::unique_ptr<RpcCommandBase> deserializeResponse(
     const Message& response);
 
+// Note: format is subject to change and intended for RPCs.
+// For saving persistently to disk, use torch::save().
+TORCH_API std::string wireSerialize(
+    const std::vector<char>& payload,
+    const std::vector<torch::Tensor>& tensors);
+
+TORCH_API std::pair<std::vector<char>, std::vector<torch::Tensor>> wireDeserialize(
+    const void* data,
+    size_t data_size);
+
 } // namespace rpc
 } // namespace distributed
 } // namespace torch
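
For reference, a minimal round trip through the new wireSerialize()/
wireDeserialize() API, mirroring the unit test above. This is a sketch that
assumes a build linking the distributed RPC sources; the file name demo.cpp
is made up for illustration.

  // demo.cpp: round-trip a payload and a tensor through the wire format.
  #include <torch/csrc/distributed/rpc/utils.h>
  #include <torch/torch.h>

  #include <cassert>
  #include <string>
  #include <vector>

  int main() {
    // Build a small payload plus one tensor, as an RPC message would.
    std::string text = "hello";
    std::vector<char> payload(text.begin(), text.end());
    std::vector<torch::Tensor> tensors = {torch::randn({5, 5})};

    // Serialize: header entries ("payload 5", "meta ...", "0 ..."), a blank
    // line, then the raw section bodies back to back.
    std::string wire = torch::distributed::rpc::wireSerialize(payload, tensors);

    // Deserialize and verify the round trip.
    auto out =
        torch::distributed::rpc::wireDeserialize(wire.data(), wire.size());
    assert(out.first == payload);
    assert(torch::equal(out.second[0], tensors[0]));
    return 0;
  }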