mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 06:34:55 +08:00
Summary: This PR adds `ModuleList` to `modules.h` so that it can be used by including `torch/torch.h`. (Edit by yf225: closes https://github.com/pytorch/pytorch/issues/25293.) Pull Request resolved: https://github.com/pytorch/pytorch/pull/25346. Differential Revision: D17115013. Pulled By: yf225. fbshipit-source-id: 38a1848b9a8272fa411865dfc83b76d10c5789a0
289 lines
8.3 KiB
C++
289 lines
8.3 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <torch/torch.h>
|
|
|
|
#include <algorithm>
|
|
#include <memory>
|
|
#include <vector>
|
|
|
|
#include <test/cpp/api/support.h>
|
|
|
|
using namespace torch::nn;
|
|
using namespace torch::test;
|
|
|
|
/// GoogleTest fixture for the `torch::nn::ModuleList` tests. It adds nothing of
/// its own; all setup is inherited from `torch::test::SeedingFixture`
/// (presumably per-test RNG seeding -- see test/cpp/api/support.h).
class ModuleListTest : public torch::test::SeedingFixture {};
|
|
|
|
TEST_F(ModuleListTest, ConstructsFromSharedPointer) {
  // Minimal module carrying an int payload; only the list length is checked.
  struct M : torch::nn::Module {
    explicit M(int v) : value(v) {}
    int value;
  };

  // The variadic constructor accepts shared_ptrs to concrete module types.
  auto one = std::make_shared<M>(1);
  auto two = std::make_shared<M>(2);
  auto three = std::make_shared<M>(3);
  ModuleList list(one, two, three);

  ASSERT_EQ(list->size(), 3);
}
|
|
|
|
TEST_F(ModuleListTest, ConstructsFromConcreteType) {
  // Counts invocations of M's copy constructor. It must have static storage
  // duration: a local class cannot use automatic variables of the enclosing
  // function.
  static int copy_count;

  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    // Copies the payload too, so the copies stored in the list hold a
    // well-defined `value` instead of an indeterminate one.
    M(const M& other) : torch::nn::Module(other), value(other.value) {
      copy_count++;
    }
    int value;
  };

  copy_count = 0;
  ModuleList list(M(1), M(2), M(3));
  ASSERT_EQ(list->size(), 3);
  // NOTE: The current implementation expects each module to be copied exactly
  // once, which happens when the module is passed into `std::make_shared<T>()`.
  // TODO: Find a way to avoid copying, and then delete the copy constructor of
  // `M`.
  ASSERT_EQ(copy_count, 3);

  // The copied modules keep their payloads, in construction order.
  ASSERT_EQ(list->at<M>(0).value, 1);
  ASSERT_EQ(list->at<M>(1).value, 2);
  ASSERT_EQ(list->at<M>(2).value, 3);
}
|
|
|
|
TEST_F(ModuleListTest, ConstructsFromModuleHolder) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int v) : value(v) {}
    int value;
  };

  // Hand-written holder (rather than TORCH_MODULE) that inherits
  // ModuleHolder's constructors and re-exposes get().
  struct M : torch::nn::ModuleHolder<MImpl> {
    using Base = torch::nn::ModuleHolder<MImpl>;
    using Base::Base;
    using Base::get;
  };

  // Holders are accepted directly by the variadic constructor.
  ModuleList list(M(1), M(2), M(3));

  ASSERT_EQ(list->size(), 3);
}
|
|
|
|
TEST_F(ModuleListTest, PushBackAddsAnElement) {
  struct M : torch::nn::Module {
    explicit M(int v) : value(v) {}
    int value;
  };

  // A default-constructed list starts out empty.
  ModuleList list;
  ASSERT_EQ(list->size(), 0);
  ASSERT_TRUE(list->is_empty());

  // push_back accepts a built-in module holder...
  list->push_back(Linear(3, 4));
  ASSERT_EQ(list->size(), 1);

  // ...a shared_ptr to a user module...
  list->push_back(std::make_shared<M>(1));
  ASSERT_EQ(list->size(), 2);

  // ...and a user module passed by value.
  list->push_back(M(2));
  ASSERT_EQ(list->size(), 3);
}
|
|
|
|
TEST_F(ModuleListTest, Insertion) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : value(value_) {}
    int value;
  };
  TORCH_MODULE(M);

  // Insert a concrete module, a shared_ptr and holders at various positions.
  ModuleList list;
  list->push_back(MImpl(1));
  ASSERT_EQ(list->size(), 1);
  list->insert(0, std::make_shared<MImpl>(2));
  ASSERT_EQ(list->size(), 2);
  list->insert(1, M(3));
  ASSERT_EQ(list->size(), 3);
  // Inserting at index == size() appends (the at(3) check below relies on it).
  list->insert(3, M(4));
  ASSERT_EQ(list->size(), 4);

  ASSERT_EQ(list->at<MImpl>(0).value, 2);
  ASSERT_EQ(list->at<MImpl>(1).value, 3);
  ASSERT_EQ(list->at<MImpl>(2).value, 1);
  ASSERT_EQ(list->at<MImpl>(3).value, 4);

  // The registered submodule names ("0", "1", ...) must track the positions
  // after the insertions. `.at()` fails loudly on an unexpected name instead
  // of default-inserting, and the visit count guards against the loop
  // passing vacuously.
  const std::unordered_map<size_t, int> expected = {
      {0, 2}, {1, 3}, {2, 1}, {3, 4}};
  size_t visited = 0;
  for (const auto& item : list->named_modules("", /*include_self=*/false)) {
    ASSERT_EQ(
        expected.at(std::stoul(item.key())), item.value()->as<M>()->value);
    ++visited;
  }
  ASSERT_EQ(visited, expected.size());
}
|
|
|
|
TEST_F(ModuleListTest, AccessWithAt) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  ModuleList list;
  for (const auto& module : modules) {
    list->push_back(module);
  }
  ASSERT_EQ(list->size(), 3);

  // returns the correct module for a given index
  for (size_t i = 0; i < modules.size(); ++i) {
    ASSERT_EQ(&list->at<M>(i), modules[i].get());
  }

  // throws for a bad index, starting at the first out-of-range index, size()
  ASSERT_THROWS_WITH(list->at<M>(modules.size()), "Index out of range");
  ASSERT_THROWS_WITH(list->at<M>(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      list->at<M>(modules.size() + 1000000), "Index out of range");
}
|
|
|
|
TEST_F(ModuleListTest, AccessWithPtr) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  ModuleList list;
  for (const auto& module : modules) {
    list->push_back(module);
  }
  ASSERT_EQ(list->size(), 3);

  // ptr(), operator[] and ptr<T>() all return the very object that was pushed
  for (size_t i = 0; i < modules.size(); ++i) {
    ASSERT_EQ(list->ptr(i).get(), modules[i].get());
    ASSERT_EQ(list[i].get(), modules[i].get());
    ASSERT_EQ(list->ptr<M>(i).get(), modules[i].get());
  }

  // throws for a bad index, starting at the first out-of-range index, size()
  ASSERT_THROWS_WITH(list->ptr(modules.size()), "Index out of range");
  ASSERT_THROWS_WITH(list->ptr(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      list->ptr(modules.size() + 1000000), "Index out of range");
}
|
|
|
|
TEST_F(ModuleListTest, SanityCheckForHoldingStandardModules) {
  ModuleList list(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm(5),
      Embedding(4, 10),
      LSTM(4, 5));

  // Beyond compiling, every element must actually be stored, in order, as the
  // module type it was constructed from.
  ASSERT_EQ(list->size(), 6);
  ASSERT_TRUE(list[0]->as<Linear>());
  ASSERT_TRUE(list[1]->as<Conv2d>());
  ASSERT_TRUE(list[2]->as<Dropout>());
  ASSERT_TRUE(list[3]->as<BatchNorm>());
  ASSERT_TRUE(list[4]->as<Embedding>());
  ASSERT_TRUE(list[5]->as<LSTM>());
}
|
|
|
|
TEST_F(ModuleListTest, ExtendPushesModulesFromOtherModuleList) {
  // Empty marker modules; only their dynamic types are inspected.
  struct A : torch::nn::Module {};
  struct B : torch::nn::Module {};
  struct C : torch::nn::Module {};
  struct D : torch::nn::Module {};

  ModuleList a(A{}, B{});
  ModuleList b(C{}, D{});

  // Extending from another ModuleList appends its elements after our own...
  a->extend(*b);
  ASSERT_EQ(a->size(), 4);
  ASSERT_TRUE(a->ptr(0)->as<A>());
  ASSERT_TRUE(a->ptr(1)->as<B>());
  ASSERT_TRUE(a->ptr(2)->as<C>());
  ASSERT_TRUE(a->ptr(3)->as<D>());

  // ...and leaves the source list untouched.
  ASSERT_EQ(b->size(), 2);
  ASSERT_TRUE(b->ptr(0)->as<C>());
  ASSERT_TRUE(b->ptr(1)->as<D>());

  // A std::vector of shared_ptrs works as a source as well.
  std::vector<std::shared_ptr<A>> extra;
  extra.push_back(std::make_shared<A>());
  extra.push_back(std::make_shared<A>());
  b->extend(extra);

  ASSERT_EQ(b->size(), 4);
  ASSERT_TRUE(b->ptr(0)->as<C>());
  ASSERT_TRUE(b->ptr(1)->as<D>());
  ASSERT_TRUE(b->ptr(2)->as<A>());
  ASSERT_TRUE(b->ptr(3)->as<A>());
}
|
|
|
|
TEST_F(ModuleListTest, HasReferenceSemantics) {
  ModuleList first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
  // Copying the holder must share the underlying implementation object.
  ModuleList second(first);

  ASSERT_EQ(first.get(), second.get());
  ASSERT_EQ(first->size(), second->size());

  // Element-wise identity (same object, not merely equal contents).
  const auto same_object = [](const std::shared_ptr<Module>& lhs,
                              const std::shared_ptr<Module>& rhs) {
    return lhs.get() == rhs.get();
  };
  ASSERT_TRUE(std::equal(
      first->begin(), first->end(), second->begin(), same_object));
}
|
|
|
|
TEST_F(ModuleListTest, IsCloneable) {
  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
  ModuleList clone = std::dynamic_pointer_cast<ModuleListImpl>(list->clone());
  ASSERT_EQ(list->size(), clone->size());

  // Position by position, the clone holds modules of the same kind (type)
  // that are nevertheless distinct objects.
  for (size_t index = 0; index < list->size(); ++index) {
    const auto original_module = list[index];
    const auto cloned_module = clone[index];
    ASSERT_EQ(original_module->name(), cloned_module->name());
    ASSERT_NE(original_module, cloned_module);
  }

  // Verify that the clone is deep, i.e. parameters of modules are cloned too.
  // Autograd is disabled, presumably so the in-place add_ on parameters below
  // is permitted.
  torch::NoGradGuard no_grad;

  auto params1 = list->named_parameters();
  auto params2 = clone->named_parameters();
  ASSERT_EQ(params1.size(), params2.size());
  for (auto& param : params1) {
    const auto& counterpart = params2[param.key()];
    ASSERT_FALSE(pointer_equal(param.value(), counterpart));
    ASSERT_EQ(param->device(), counterpart.device());
    ASSERT_TRUE(param->allclose(counterpart));
    // Mutate only the original; the clone must not observe the change.
    param->add_(2);
  }
  for (auto& param : params1) {
    ASSERT_FALSE(param->allclose(params2[param.key()]));
  }
}
|
|
|
|
TEST_F(ModuleListTest, RegistersElementsAsSubmodules) {
  ModuleList list(Linear(10, 3), Conv2d(1, 2, 3), FeatureDropout(0.5));

  // Each element must show up as a child module, in insertion order. Check the
  // count first so that the indexed accesses below cannot read past the end.
  auto modules = list->children();
  ASSERT_EQ(modules.size(), 3);
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  ASSERT_TRUE(modules[2]->as<FeatureDropout>());
}
|
|
|
|
TEST_F(ModuleListTest, NestingIsPossible) {
  // Brace-initialization avoids the most vexing parse without the extra
  // parentheses that invite accidental comma-operator expressions, which would
  // silently pass only the last operand.
  ModuleList list{
      ModuleList{Dropout(), Dropout()},
      ModuleList{ModuleList{Dropout(), Dropout()}, Dropout()}};

  ASSERT_EQ(list->size(), 2);

  // First element: a flat list of two dropouts.
  auto& first = list->at<ModuleListImpl>(0);
  ASSERT_EQ(first.size(), 2);

  // Second element: a list containing a nested list and a dropout.
  auto& second = list->at<ModuleListImpl>(1);
  ASSERT_EQ(second.size(), 2);
  ASSERT_EQ(second.at<ModuleListImpl>(0).size(), 2);
  ASSERT_TRUE(second.ptr(1)->as<Dropout>());
}
|
|
|
|
TEST_F(ModuleListTest, CloneToDevice_CUDA) {
  ModuleList list(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
  const auto device = torch::Device(torch::kCUDA, 0);
  ModuleList clone =
      std::dynamic_pointer_cast<ModuleListImpl>(list->clone(device));

  // Every parameter and every buffer of the clone must live on `device`.
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device(), device);
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device(), device);
  }
}
|
|
|
|
TEST_F(ModuleListTest, PrettyPrintModuleList) {
  ModuleList list(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm(5),
      Embedding(4, 10),
      LSTM(4, 5));

  // Expected rendering: one "(index): module" line per element, in order.
  const std::string expected =
      "torch::nn::ModuleList(\n"
      " (0): torch::nn::Linear(in=10, out=3, with_bias=true)\n"
      " (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
      " (2): torch::nn::Dropout(rate=0.5)\n"
      " (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
      " (4): torch::nn::Embedding(count=4, dimension=10)\n"
      " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
      ")";
  ASSERT_EQ(c10::str(list), expected);
}
|