mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66744 Modified loops in files under fbsource/fbcode/caffe2/ from the format `for(TYPE var=x0;var<x_max;x++)` to the format `for(const auto var: irange(xmax))` This was achieved by running r-barnes's loop upgrader script (D28874212) with some modification to exclude all files under /torch/jit and a number of reversions or unused variable suppression warnings added by hand. Test Plan: Sandcastle Reviewed By: ngimel Differential Revision: D31705358 fbshipit-source-id: d6ea350cbaa8f452fc78f238160e5374be637a48
106 lines
3.7 KiB
C++
106 lines
3.7 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <c10/util/irange.h>
|
|
#include <torch/csrc/distributed/rpc/utils.h>
|
|
#include <torch/torch.h>
|
|
|
|
#include <memory>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
using ::testing::IsSubstring;
|
|
|
|
TEST(WireSerialize, Base) {
|
|
auto run = [](const std::string& payload,
|
|
const std::vector<at::Tensor>& tensors) {
|
|
std::string serialized;
|
|
{
|
|
std::vector<char> mpayload(payload.begin(), payload.end());
|
|
std::vector<at::Tensor> mtensors = tensors;
|
|
serialized = torch::distributed::rpc::wireSerialize(
|
|
std::move(mpayload), std::move(mtensors));
|
|
}
|
|
auto deser = torch::distributed::rpc::wireDeserialize(
|
|
serialized.data(), serialized.size());
|
|
EXPECT_EQ(payload.size(), deser.first.size());
|
|
EXPECT_EQ(tensors.size(), deser.second.size());
|
|
if (payload.size() > 0) {
|
|
EXPECT_TRUE(
|
|
memcmp(deser.first.data(), payload.data(), payload.size()) == 0);
|
|
}
|
|
for (const auto i : c10::irange(tensors.size())) {
|
|
EXPECT_TRUE(torch::equal(tensors[i], deser.second[i]));
|
|
}
|
|
};
|
|
run("", {});
|
|
run("hi", {});
|
|
run("", {torch::randn({5, 5})});
|
|
run("hi", {torch::randn({5, 5})});
|
|
run("more", {torch::randn({5, 5}), torch::rand({10, 10})});
|
|
}
|
|
|
|
TEST(WireSerialize, RecopySparseTensors) {
|
|
// Take a 1K row of a 1M tensors, and make sure we don't send across 1M rows.
|
|
constexpr size_t k1K = 1024;
|
|
at::Tensor main = torch::randn({k1K, k1K});
|
|
at::Tensor tiny = main.select(0, 2); // Select a row in the middle
|
|
EXPECT_EQ(tiny.numel(), k1K);
|
|
EXPECT_EQ(tiny.storage().nbytes() / tiny.dtype().itemsize(), k1K * k1K);
|
|
auto ser = torch::distributed::rpc::wireSerialize({}, {tiny});
|
|
auto deser = torch::distributed::rpc::wireDeserialize(ser.data(), ser.size());
|
|
EXPECT_TRUE(torch::equal(tiny, deser.second[0]));
|
|
EXPECT_LT(ser.size(), (tiny.element_size() * k1K) + k1K);
|
|
}
|
|
|
|
TEST(WireSerialize, CloneSparseTensors) {
|
|
constexpr size_t k1K = 1024;
|
|
at::Tensor big = torch::randn({k1K, k1K});
|
|
auto v1 = torch::distributed::rpc::cloneSparseTensors({big});
|
|
EXPECT_EQ(v1.get(0).storage(), big.storage()); // Not cloned
|
|
|
|
at::Tensor tiny = big.select(0, 2); // Select a row in the middle
|
|
auto v2 = torch::distributed::rpc::cloneSparseTensors({tiny});
|
|
EXPECT_NE(&v2.get(0).storage(), &tiny.storage()); // Cloned.
|
|
EXPECT_TRUE(torch::equal(v2.get(0), tiny));
|
|
|
|
at::Tensor sparse = at::empty({2, 3}, at::dtype<float>().layout(at::kSparse));
|
|
auto v3 = torch::distributed::rpc::cloneSparseTensors({sparse});
|
|
// There is no storage() to compare, but at least confirm equality.
|
|
EXPECT_TRUE(v3.get(0).is_same(sparse));
|
|
}
|
|
|
|
TEST(WireSerialize, Errors) {
|
|
auto checkMessage = [](auto&& f, const char* msg) {
|
|
try {
|
|
f();
|
|
FAIL();
|
|
} catch (const std::exception& e) {
|
|
EXPECT_PRED_FORMAT2(IsSubstring, msg, e.what());
|
|
} catch (...) {
|
|
FAIL();
|
|
}
|
|
};
|
|
checkMessage(
|
|
[]() { (void)torch::distributed::rpc::wireDeserialize("", 0); },
|
|
"failed parse");
|
|
checkMessage(
|
|
[]() { (void)torch::distributed::rpc::wireDeserialize(" ", 1); },
|
|
"failed parse");
|
|
auto serialized =
|
|
torch::distributed::rpc::wireSerialize({}, {torch::randn({5, 5})});
|
|
checkMessage(
|
|
[&]() {
|
|
(void)torch::distributed::rpc::wireDeserialize(
|
|
serialized.data(), serialized.size() / 2);
|
|
},
|
|
"failed bounds");
|
|
}
|
|
|
|
// Enable this once JIT Pickler supports sparse tensors.
|
|
TEST(WireSerialize, DISABLED_Sparse) {
|
|
at::Tensor main = at::empty({2, 3}, at::dtype<float>().layout(at::kSparse));
|
|
auto ser = torch::distributed::rpc::wireSerialize({}, {main.to(at::kSparse)});
|
|
auto deser = torch::distributed::rpc::wireDeserialize(ser.data(), ser.size());
|
|
EXPECT_TRUE(torch::equal(main, deser.second[0]));
|
|
}
|