#include #include #include #include #include #include using namespace torch::nn; using namespace torch::test; struct ParameterListTest : torch::test::SeedingFixture {}; TEST_F(ParameterListTest, ConstructsFromSharedPointer) { torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); torch::Tensor tc = torch::randn({1, 2}); ASSERT_TRUE(ta.requires_grad()); ASSERT_FALSE(tb.requires_grad()); ParameterList list(ta, tb, tc); ASSERT_EQ(list->size(), 3); } TEST_F(ParameterListTest, isEmpty) { torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); ParameterList list; ASSERT_TRUE(list->is_empty()); list->append(ta); ASSERT_FALSE(list->is_empty()); ASSERT_EQ(list->size(), 1); } TEST_F(ParameterListTest, PushBackAddsAnElement) { ParameterList list; torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); torch::Tensor tc = torch::randn({1, 2}); torch::Tensor td = torch::randn({1, 2, 3}); ASSERT_EQ(list->size(), 0); ASSERT_TRUE(list->is_empty()); list->append(ta); ASSERT_EQ(list->size(), 1); list->append(tb); ASSERT_EQ(list->size(), 2); list->append(tc); ASSERT_EQ(list->size(), 3); list->append(td); ASSERT_EQ(list->size(), 4); } TEST_F(ParameterListTest, ForEachLoop) { torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); torch::Tensor tc = torch::randn({1, 2}); torch::Tensor td = torch::randn({1, 2, 3}); ParameterList list(ta, tb, tc, td); std::vector params = {ta, tb, tc, td}; ASSERT_EQ(list->size(), 4); int idx = 0; for (const auto& pair : *list) { ASSERT_TRUE( torch::all(torch::eq(pair.value(), params[idx++])).item()); } } TEST_F(ParameterListTest, AccessWithAt) { torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); torch::Tensor tc = torch::randn({1, 2}); torch::Tensor td = torch::randn({1, 2, 3}); std::vector params = {ta, tb, tc, td}; ParameterList list; for (auto& param : params) { list->append(param); } ASSERT_EQ(list->size(), 4); // returns the correct module for a given index for (size_t i = 0; i < params.size(); ++i) { ASSERT_TRUE(torch::all(torch::eq(list->at(i), params[i])).item()); } for (size_t i = 0; i < params.size(); ++i) { ASSERT_TRUE(torch::all(torch::eq(list[i], params[i])).item()); } // throws for a bad index ASSERT_THROWS_WITH(list->at(params.size() + 100), "Index out of range"); ASSERT_THROWS_WITH(list->at(params.size() + 1), "Index out of range"); ASSERT_THROWS_WITH(list[params.size() + 1], "Index out of range"); } TEST_F(ParameterListTest, ExtendPushesParametersFromOtherParameterList) { torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); torch::Tensor tc = torch::randn({1, 2}); torch::Tensor td = torch::randn({1, 2, 3}); torch::Tensor te = torch::randn({1, 2}); torch::Tensor tf = torch::randn({1, 2, 3}); ParameterList a(ta, tb); ParameterList b(tc, td); a->extend(*b); ASSERT_EQ(a->size(), 4); ASSERT_TRUE(torch::all(torch::eq(a[0], ta)).item()); ASSERT_TRUE(torch::all(torch::eq(a[1], tb)).item()); ASSERT_TRUE(torch::all(torch::eq(a[2], tc)).item()); ASSERT_TRUE(torch::all(torch::eq(a[3], td)).item()); ASSERT_EQ(b->size(), 2); ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item()); ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item()); std::vector c = {te, tf}; b->extend(c); ASSERT_EQ(b->size(), 4); ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item()); ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item()); ASSERT_TRUE(torch::all(torch::eq(b[2], te)).item()); ASSERT_TRUE(torch::all(torch::eq(b[3], tf)).item()); } TEST_F(ParameterListTest, PrettyPrintParameterList) { torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); torch::Tensor tc = torch::randn({1, 2}); ParameterList list(ta, tb, tc); ASSERT_EQ( c10::str(list), "torch::nn::ParameterList(\n" "(0): Parameter containing: [Float of size [1, 2]]\n" "(1): Parameter containing: [Float of size [1, 2]]\n" "(2): Parameter containing: [Float of size [1, 2]]\n" ")"); } TEST_F(ParameterListTest, IncrementAdd) { torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); torch::Tensor tc = torch::randn({1, 2}); torch::Tensor td = torch::randn({1, 2, 3}); torch::Tensor te = torch::randn({1, 2}); torch::Tensor tf = torch::randn({1, 2, 3}); ParameterList listA(ta, tb, tc); ParameterList listB(td, te, tf); std::vector tensors{ta, tb, tc, td, te, tf}; int idx = 0; *listA += *listB; ASSERT_TRUE(torch::all(torch::eq(listA[0], ta)).item()); ASSERT_TRUE(torch::all(torch::eq(listA[1], tb)).item()); ASSERT_TRUE(torch::all(torch::eq(listA[2], tc)).item()); ASSERT_TRUE(torch::all(torch::eq(listA[3], td)).item()); ASSERT_TRUE(torch::all(torch::eq(listA[4], te)).item()); ASSERT_TRUE(torch::all(torch::eq(listA[5], tf)).item()); for (const auto& P : listA->named_parameters(false)) ASSERT_TRUE(torch::all(torch::eq(P.value(), tensors[idx++])).item()); ASSERT_EQ(idx, 6); }