mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-03 15:35:04 +08:00
Avoid sending large unneeded data over wire in process_group_agent. (#31357)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31357 If a user selects a subset of a Tensor and sends it in an RPC, we were sending the whole original Tensor Storage over the network. While this sounds reasonable, in practice, we observed view-like Tensors being sent over rpc, where only 1% of the data in the provided Tensor's Storage was actually used/needed. The simple solution here is to just force a clone in the serializer code if we see that less than (arbitrary) half the bits are used, and the tensor is more than a nominal few KB. Add related tests to ensure this doesn't break. An alternate approach would be to modify the Pickler. That said, since Pickler is shared by more components, the logic might be harder to tailor appropriately at that layer (particularly given that the Pickler has explicit logic to share a single Storage* among several Tensors that commonly point to the same Storage*). It's possible that we might want to further refine the basic thresholds in this change. In practice, we've seen a mostly bimodal distribution thus far for the percent of Tensor Storage referred by a Tensor in observed rpcs (i.e. either 90%+ or sub-10% of the Storage referenced), hence the existing 50% threshold here is probably not an unreasonable starting point. ghstack-source-id: 95925474 Test Plan: buck test mode/dev caffe2/test/cpp/rpc/... Differential Revision: D19137056 fbshipit-source-id: e2b3a4dd0cc6e1de820fd0740aa1d59883dbf8d4
This commit is contained in:
committed by
Facebook Github Bot
parent
1bb800cf5c
commit
dff7b945bf
@ -37,3 +37,16 @@ TEST(WireSerialize, Base) {
|
||||
run("hi", {torch::randn({5, 5})});
|
||||
run("more", {torch::randn({5, 5}), torch::rand({10, 10})});
|
||||
}
|
||||
|
||||
TEST(WireSerialize, RecopySparseTensors) {
|
||||
// Take a 1K row of a 1M tensors, and make sure we don't send across 1M rows.
|
||||
constexpr size_t k1K = 1024;
|
||||
at::Tensor main = torch::randn({k1K, k1K});
|
||||
at::Tensor tiny = main.select(0, 2); // Select a row in the middle
|
||||
EXPECT_EQ(tiny.numel(), k1K);
|
||||
EXPECT_EQ(tiny.storage().numel(), k1K * k1K);
|
||||
auto ser = torch::distributed::rpc::wireSerialize({}, {tiny});
|
||||
auto deser = torch::distributed::rpc::wireDeserialize(ser.data(), ser.size());
|
||||
EXPECT_TRUE(torch::equal(tiny, deser.second[0]));
|
||||
EXPECT_LT(ser.size(), (tiny.element_size() * k1K) + k1K);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user