Avoid sending large unneeded data over wire in process_group_agent. (#31357)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31357

Previously, if a user selected a subset of a Tensor and sent it in an RPC, we sent
the whole of the original Tensor's Storage over the network.

While sharing the underlying Storage sounds reasonable in principle, in practice we
observed view-like Tensors being sent over RPC where only ~1% of the data in the
provided Tensor's Storage was actually used.
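
For illustration, a minimal sketch of the pattern in question (variable names are
illustrative; the test added below exercises exactly this case):

  // A row view shares its parent's Storage, so naively serializing the
  // Storage would ship ~1024x more data than the view actually needs.
  at::Tensor big = torch::randn({1024, 1024}); // 1M elements in one Storage
  at::Tensor row = big.select(0, 2);           // 1024-element view, same Storage
  // row.numel() == 1024, but row.storage().numel() == 1024 * 1024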

The simple solution here is to force a clone in the serializer code when we see that
less than half of the Storage's bits are actually referenced (an arbitrary threshold)
and the Tensor is larger than a nominal few KB; see the sketch below.
Related tests are added to ensure this doesn't break anything.
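
A minimal sketch of that heuristic, assuming a helper in the serializer (the
function name and the exact byte threshold below are illustrative, not the code
in this diff):

at::Tensor cloneIfSparselyReferenced(const at::Tensor& t) {
  const size_t usedBytes = t.numel() * t.element_size();
  const size_t storageBytes = t.storage().numel() * t.element_size();
  constexpr size_t kMinStorageBytes = 4 * 1024; // nominal "few KB" floor (illustrative)
  // Re-copy only when the Tensor references less than half of its Storage
  // and the Storage is large enough for the copy to pay off on the wire.
  if (storageBytes >= kMinStorageBytes && usedBytes * 2 < storageBytes) {
    return t.clone(); // contiguous copy backed by a right-sized Storage
  }
  return t;
}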

An alternate approach would be to modify the Pickler. That said, since the Pickler is
shared by more components, the logic would be harder to tailor appropriately at that
layer, particularly given that the Pickler has explicit logic to share a single
Storage* among the several Tensors that commonly point to the same Storage*.
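
For context, the sharing behavior in question looks roughly like this
(illustrative snippet; variable names are not from the diff):

  // Two views, one Storage: per the note above, the Pickler deliberately
  // shares the single Storage* between both Tensors, so unconditionally
  // cloning at that layer would defeat this de-duplication.
  at::Tensor base = torch::randn({1024});
  at::Tensor left = base.narrow(0, 0, 512);
  at::Tensor right = base.narrow(0, 512, 512);
  // left and right intentionally reference the same Storage as base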

It's possible that we may want to further refine the thresholds in this change.
In practice, we've seen a mostly bimodal distribution so far in the fraction of a
Tensor's Storage actually referenced by Tensors in observed RPCs (i.e. either 90%+
or sub-10% of the Storage is referenced), so the 50% threshold here is probably a
reasonable starting point.
ghstack-source-id: 95925474

Test Plan: buck test mode/dev caffe2/test/cpp/rpc/...

Differential Revision: D19137056

fbshipit-source-id: e2b3a4dd0cc6e1de820fd0740aa1d59883dbf8d4
Author: Jeremy Lilley
Date: 2019-12-18 19:21:58 -08:00
Committed by: Facebook Github Bot
Parent: 1bb800cf5c
Commit: dff7b945bf
3 changed files with 43 additions and 1 deletion

@@ -37,3 +37,16 @@ TEST(WireSerialize, Base) {
   run("hi", {torch::randn({5, 5})});
   run("more", {torch::randn({5, 5}), torch::rand({10, 10})});
 }
+
+TEST(WireSerialize, RecopySparseTensors) {
+  // Take a 1K-element row of a 1M-element tensor; make sure we don't send all 1M elements.
+  constexpr size_t k1K = 1024;
+  at::Tensor main = torch::randn({k1K, k1K});
+  at::Tensor tiny = main.select(0, 2); // Select a row in the middle
+  EXPECT_EQ(tiny.numel(), k1K);
+  EXPECT_EQ(tiny.storage().numel(), k1K * k1K);
+  auto ser = torch::distributed::rpc::wireSerialize({}, {tiny});
+  auto deser = torch::distributed::rpc::wireDeserialize(ser.data(), ser.size());
+  EXPECT_TRUE(torch::equal(tiny, deser.second[0]));
+  EXPECT_LT(ser.size(), (tiny.element_size() * k1K) + k1K);
+}