mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes https://github.com/pytorch/pytorch/issues/73565.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93074.
Approved by: https://github.com/albanD.
310 lines
9.9 KiB
C++
310 lines
9.9 KiB
C++
#include <gtest/gtest.h>
|
|
#include <torch/torch.h>
|
|
#include <algorithm>
|
|
#include <memory>
|
|
#include <vector>
|
|
|
|
#include <test/cpp/api/support.h>
|
|
|
|
using namespace torch::nn;
|
|
using namespace torch::test;
|
|
|
|
// Fixture shared by every ModuleDict test case in this file.
// NOTE(review): presumably SeedingFixture seeds the RNG so runs are
// reproducible -- confirm in test/cpp/api/support.h.
struct ModuleDictTest : torch::test::SeedingFixture {};
|
|
|
|
// A ModuleDict built from a vector of (name, module) pairs must hold one
// entry per pair.
TEST_F(ModuleDictTest, ConstructsFromList) {
  // Minimal module that only carries an integer payload.
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  using NamedModule = std::pair<std::string, std::shared_ptr<Module>>;
  std::vector<NamedModule> named_modules = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)},
      {"module_3", std::make_shared<M>(3)}};

  ModuleDict dict(named_modules);
  ASSERT_EQ(dict->size(), 3);
}
|
|
|
|
// A ModuleDict built from a torch::OrderedDict must hold one entry per
// key/value pair of the source dictionary.
TEST_F(ModuleDictTest, ConstructsFromordereddict) {
  // Minimal module that only carries an integer payload.
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  using ModuleMap = torch::OrderedDict<std::string, std::shared_ptr<Module>>;
  ModuleMap source = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)},
      {"module_3", std::make_shared<M>(3)},
  };

  ModuleDict dict(source);
  ASSERT_EQ(dict->size(), 3);
}
|
|
|
|
// Exercises the mutating API of ModuleDict: update() from a list, from an
// OrderedDict and from another ModuleDict, then pop() and clear(), checking
// size()/contains()/empty() after each step.
TEST_F(ModuleDictTest, UpdatePopClearContains) {
  // Minimal module that only carries an integer payload.
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  ModuleDict dict;
  ASSERT_TRUE(dict->empty());

  // Update by List
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = {
      {"module_1", std::make_shared<M>(1)}};
  dict->update(list1);
  ASSERT_EQ(dict->size(), 1);
  ASSERT_TRUE(dict->contains("module_1"));

  // Update by OrderedDict
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"module_2", std::make_shared<M>(2)}};
  dict->update(ordereddict);
  ASSERT_EQ(dict->size(), 2);
  ASSERT_TRUE(dict->contains("module_2"));

  // Update by another ModuleDict
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = {
      {"module_3", std::make_shared<M>(3)}};
  ModuleDict updatedict(list2);
  dict->update(*updatedict);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_TRUE(dict->contains("module_3"));

  // Pop
  dict->pop("module_1");
  ASSERT_EQ(dict->size(), 2);
  // The size alone does not prove that the *requested* key was removed.
  ASSERT_FALSE(dict->contains("module_1"));

  // Popping a key that does not exist must throw.
  ASSERT_THROWS_WITH(dict->pop("module_4"), " 'module_4' is not defined");

  // Clear
  dict->clear();
  ASSERT_EQ(dict->size(), 0);
  ASSERT_TRUE(dict->empty());
}
|
|
|
|
// update() must overwrite the module stored under an existing key and append
// entries for new keys, whichever container type the update comes from.
TEST_F(ModuleDictTest, UpdateExist) {
  // Minimal module that only carries an integer payload.
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  using NamedModule = std::pair<std::string, std::shared_ptr<Module>>;
  using ModuleMap = torch::OrderedDict<std::string, std::shared_ptr<Module>>;

  std::vector<NamedModule> initial = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)}};
  ModuleDict dict(initial);
  ASSERT_EQ(dict->at<M>("module_2").value, 2);

  // Source 1: a vector of pairs (replaces module_2, adds module_3).
  std::vector<NamedModule> from_list = {
      {"module_2", std::make_shared<M>(0)},
      {"module_3", std::make_shared<M>(3)}};
  dict->update(from_list);
  ASSERT_EQ(dict->size(), 3);
  ASSERT_EQ(dict->at<M>("module_2").value, 0);

  // Source 2: an OrderedDict (replaces module_3, adds module_4).
  ModuleMap from_ordered = {
      {"module_3", std::make_shared<M>(0)},
      {"module_4", std::make_shared<M>(4)}};
  dict->update(from_ordered);
  ASSERT_EQ(dict->size(), 4);
  ASSERT_EQ(dict->at<M>("module_3").value, 0);

  // Source 3: another ModuleDict (replaces module_4 and module_1, adds none).
  std::vector<NamedModule> other_entries = {
      {"module_4", std::make_shared<M>(0)},
      {"module_1", std::make_shared<M>(0)}};
  ModuleDict other(other_entries);
  dict->update(*other);
  ASSERT_EQ(dict->size(), 4);
  ASSERT_EQ(dict->at<M>("module_1").value, 0);
  ASSERT_EQ(dict->at<M>("module_4").value, 0);
}
|
|
|
|
// keys() must list the keys in insertion order, operator[] must return the
// module stored under a key, and indexing an unknown key must throw.
TEST_F(ModuleDictTest, Keys) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
  };
  ModuleDict dict(ordereddict);

  const auto& keys = dict->keys();
  std::vector<std::string> expected{"linear", "conv", "dropout"};
  ASSERT_EQ(keys, expected);

  // Looking up a key that was never inserted is an error.
  ASSERT_THROWS_WITH(dict["batch"], " 'batch' is not defined");

  // Each key resolves to a module of the type it was registered with.
  ASSERT_TRUE(dict["linear"]->as<Linear>());
  ASSERT_TRUE(dict["conv"]->as<Conv2d>());
  ASSERT_TRUE(dict["dropout"]->as<Dropout>());
}
|
|
|
|
// values() and iteration must expose exactly the module pointers that were
// handed to the constructor, in the same order.
TEST_F(ModuleDictTest, Values) {
  // Minimal module that only carries an integer payload.
  struct M : Module {
    explicit M(int value_) : value(value_) {}
    int value;
  };

  torch::OrderedDict<std::string, std::shared_ptr<Module>> source = {
      {"module_1", std::make_shared<M>(1)},
      {"module_2", std::make_shared<M>(2)},
  };
  ModuleDict dict(source);

  const auto& actual_values = dict->values();
  const auto& expected_values = source.values();
  ASSERT_EQ(actual_values, expected_values);

  // Item-wise comparison: both containers must reference the same objects.
  const auto same_module = [](const auto& lhs, const auto& rhs) {
    return lhs.value().get() == rhs.value().get();
  };
  ASSERT_TRUE(
      std::equal(dict->begin(), dict->end(), source.begin(), same_module));
}
|
|
|
|
// Smoke test: constructing a ModuleDict from a mix of built-in module types
// must succeed. There are no assertions beyond successful construction.
TEST_F(ModuleDictTest, SanityCheckForHoldingStandardModules) {
  using ModuleMap = torch::OrderedDict<std::string, std::shared_ptr<Module>>;
  ModuleMap standard_modules = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
      {"batch", BatchNorm2d(5).ptr()},
      {"embedding", Embedding(4, 10).ptr()},
      {"lstm", LSTM(4, 5).ptr()}};
  ModuleDict dict(standard_modules);
}
|
|
|
|
// Two ModuleDicts built from the same OrderedDict must point at the very same
// module objects rather than at copies.
TEST_F(ModuleDictTest, HasReferenceSemantics) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> shared_modules = {
      {"linear1", Linear(2, 3).ptr()},
      {"linear2", Linear(3, 4).ptr()},
      {"linear3", Linear(4, 5).ptr()},
  };
  ModuleDict first(shared_modules);
  ModuleDict second(shared_modules);

  ASSERT_EQ(first->size(), second->size());

  // Compare raw addresses so that equality means "same object".
  const auto same_module = [](const auto& lhs, const auto& rhs) {
    return lhs.value().get() == rhs.value().get();
  };
  ASSERT_TRUE(
      std::equal(first->begin(), first->end(), second->begin(), same_module));
}
|
|
|
|
// Shared body of the IsCloneable tests.
//
// Builds a ModuleDict, moves it to `device`, clones it onto `device`, and
// checks that the clone has the same keys and module names but distinct module
// objects and distinct parameter storage with equal values.
void iscloneable_helper(torch::Device device) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> submodules = {
      {"linear", Linear(2, 3).ptr()},
      {"relu", Functional(torch::relu).ptr()},
      {"batch", BatchNorm1d(3).ptr()},
  };
  ModuleDict dict(submodules);
  dict->to(device);

  ModuleDict clone =
      std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device));
  ASSERT_EQ(dict->size(), clone->size());

  // Walk both dictionaries in lockstep; the sizes were just checked to match.
  auto original_it = dict->begin();
  auto cloned_it = clone->begin();
  while (original_it != dict->end()) {
    // Same key ...
    ASSERT_EQ(original_it->key(), cloned_it->key());
    // ... same kind of module ...
    ASSERT_EQ(original_it->value()->name(), cloned_it->value()->name());
    // ... but a different object.
    ASSERT_NE(original_it->value(), cloned_it->value());
    ++original_it;
    ++cloned_it;
  }

  // The clone must be deep: module parameters are cloned as well.
  torch::NoGradGuard no_grad;

  auto original_params = dict->named_parameters();
  auto cloned_params = clone->named_parameters();
  ASSERT_EQ(original_params.size(), cloned_params.size());

  for (auto& entry : original_params) {
    auto& counterpart = cloned_params[entry.key()];
    ASSERT_FALSE(pointer_equal(entry.value(), counterpart));
    ASSERT_EQ(entry->device(), counterpart.device());
    ASSERT_TRUE(entry->allclose(counterpart));
    // Mutate the original in place; the clone must not follow.
    entry->add_(2);
  }
  for (auto& entry : original_params) {
    ASSERT_FALSE(entry->allclose(cloned_params[entry.key()]));
  }
}
|
|
|
|
// Runs the clone checks on the CPU device.
TEST_F(ModuleDictTest, IsCloneable) {
  iscloneable_helper(torch::Device(torch::kCPU));
}
|
|
|
|
// Runs the clone checks on CUDA device index 0.
TEST_F(ModuleDictTest, IsCloneable_CUDA) {
  iscloneable_helper(torch::Device(torch::kCUDA, 0));
}
|
|
|
|
// The entries of a ModuleDict must show up in children(), in insertion order.
// Updating an existing key must replace the child in place (order preserved)
// while new keys are appended at the end.
TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict1 = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"test", Dropout(0.5).ptr()},
  };
  ModuleDict dict(ordereddict1);

  auto modules = dict->children();
  // Check the count first so a wrong count fails the test rather than
  // indexing past the end of the vector.
  ASSERT_EQ(modules.size(), 3);
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  ASSERT_TRUE(modules[2]->as<Dropout>());

  // Update Existing: "test" is replaced, "lstm" is new.
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict2 = {
      {"lstm", LSTM(4, 5).ptr()}, {"test", BatchNorm2d(5).ptr()}};
  dict->update(ordereddict2);

  modules = dict->children();
  ASSERT_EQ(modules.size(), 4);
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  // Keep Order: the replaced "test" entry stays at its original position.
  ASSERT_TRUE(modules[2]->as<BatchNorm2d>());
  ASSERT_TRUE(modules[3]->as<LSTM>());
}
|
|
|
|
// Cloning a ModuleDict onto a CUDA device must place every parameter and every
// buffer of the clone on that device.
TEST_F(ModuleDictTest, CloneToDevice_CUDA) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> submodules = {
      {"linear", Linear(2, 3).ptr()},
      {"relu", Functional(torch::relu).ptr()},
      {"batch", BatchNorm1d(3).ptr()},
  };
  ModuleDict dict(submodules);

  torch::Device device(torch::kCUDA, 0);
  ModuleDict clone =
      std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device));

  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device(), device);
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device(), device);
  }
}
|
|
|
|
// Streaming a ModuleDict must print the container name followed by one
// "(key): <module description>" line per entry, in insertion order.
TEST_F(ModuleDictTest, PrettyPrintModuleDict) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> submodules = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
      {"batch", BatchNorm2d(5).ptr()},
      {"embedding", Embedding(4, 10).ptr()},
      {"lstm", LSTM(4, 5).ptr()}};
  ModuleDict dict(submodules);

  // NOTE(review): the comparison is exact, including the leading whitespace of
  // each child line -- confirm the indentation below matches the printer.
  const std::string expected =
      "torch::nn::ModuleDict(\n"
      " (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      " (conv): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      " (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
      " (batch): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      " (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")";
  ASSERT_EQ(c10::str(dict), expected);
}
|
|
|
|
// at<T>() must throw when the module stored under the key is not of type T.
TEST_F(ModuleDictTest, InvalidAt) {
  using ModuleMap = torch::OrderedDict<std::string, std::shared_ptr<Module>>;
  ModuleMap only_linear = {{"linear", Linear(10, 3).ptr()}};
  ModuleDict dict(only_linear);

  // "linear" holds a Linear module, so asking for Dropout2dImpl must fail.
  ASSERT_THROWS_WITH(
      dict->at<torch::nn::Dropout2dImpl>("linear"), "Unable to cast module");
}
|