mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
Let's have some fun. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78828 Approved by: https://github.com/ezyang
674 lines
23 KiB
C++
674 lines
23 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <c10/util/irange.h>
|
|
#include <torch/torch.h>
|
|
|
|
#include <algorithm>
|
|
#include <memory>
|
|
#include <vector>
|
|
|
|
#include <test/cpp/api/support.h>
|
|
|
|
using namespace torch::nn;
|
|
using namespace torch::test;
|
|
|
|
// Fixture shared by every Sequential test below. It only inherits from
// SeedingFixture (declared in test/cpp/api/support.h); presumably that seeds
// the RNG per test — NOTE(review): confirm in support.h.
struct SequentialTest : public torch::test::SeedingFixture {};
|
|
|
|
// A Sequential can be built from a heterogeneous mix of standard modules.
TEST_F(SequentialTest, CanContainThings) {
  Sequential mixed_modules(Linear(3, 4), ReLU(), BatchNorm1d(3));
}
|
|
|
|
// Sequential can be built from std::shared_ptr<Module>s, either positionally
// or as a braced list of (name, module) pairs.
TEST_F(SequentialTest, ConstructsFromSharedPointer) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
    int forward() {
      return value;
    }
  };
  Sequential sequential(
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
  ASSERT_EQ(sequential->size(), 3);

  Sequential sequential_named(
      {{"m1", std::make_shared<M>(1)},
       {std::string("m2"), std::make_shared<M>(2)},
       {"m3", std::make_shared<M>(3)}});
  // Check the named container itself, not the positional one built above.
  ASSERT_EQ(sequential_named->size(), 3);
  // The given names become the registered child keys, in order.
  ASSERT_EQ(sequential_named->named_children()[0].key(), "m1");
  ASSERT_EQ(sequential_named->named_children()[1].key(), "m2");
  ASSERT_EQ(sequential_named->named_children()[2].key(), "m3");
}
|
|
|
|
// Sequential can be built from module values (not pointers), positionally or
// as (name, module) pairs. copy_count tracks how often M is copied.
TEST_F(SequentialTest, ConstructsFromConcreteType) {
  static int copy_count;

  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    // Copy `value` too, so copied modules are fully initialized.
    M(const M& other) : torch::nn::Module(other), value(other.value) {
      copy_count++;
    }
    int value;
    int forward() {
      return value;
    }
  };

  copy_count = 0;
  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);
  // NOTE: The current implementation expects each module to be copied exactly
  // once, which happens when the module is passed into `std::make_shared<T>()`.
  // TODO: Find a way to avoid copying, and then delete the copy constructor of
  // `M`.
  ASSERT_EQ(copy_count, 3);

  copy_count = 0;
  Sequential sequential_named(
      {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}});
  // Check the named container itself, not the positional one built above.
  ASSERT_EQ(sequential_named->size(), 3);
  ASSERT_EQ(sequential_named->named_children()[0].key(), "m1");
  ASSERT_EQ(sequential_named->named_children()[1].key(), "m2");
  ASSERT_EQ(sequential_named->named_children()[2].key(), "m3");
  ASSERT_EQ(copy_count, 3);
}
|
|
|
|
// Sequential can be built from ModuleHolder wrappers, positionally or as
// (name, holder) pairs.
TEST_F(SequentialTest, ConstructsFromModuleHolder) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };

  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);

  Sequential sequential_named(
      {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}});
  // Check the named container itself, not the positional one built above.
  ASSERT_EQ(sequential_named->size(), 3);
  ASSERT_EQ(sequential_named->named_children()[0].key(), "m1");
  ASSERT_EQ(sequential_named->named_children()[1].key(), "m2");
  ASSERT_EQ(sequential_named->named_children()[2].key(), "m3");
}
|
|
|
|
// push_back() grows the container by one for every supported argument form:
// module holders, shared pointers, module values and AnyModules, with or
// without an explicit name. Unnamed entries are keyed by their position.
TEST_F(SequentialTest, PushBackAddsAnElement) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };

  // Only unnamed submodules.
  Sequential unnamed;
  ASSERT_EQ(unnamed->size(), 0);
  ASSERT_TRUE(unnamed->is_empty());
  unnamed->push_back(Linear(3, 4));
  ASSERT_EQ(unnamed->size(), 1);
  unnamed->push_back(std::make_shared<M>(1));
  ASSERT_EQ(unnamed->size(), 2);
  unnamed->push_back(M(2));
  ASSERT_EQ(unnamed->size(), 3);

  // Interleaved named and unnamed submodules.
  Sequential mixed;
  ASSERT_EQ(mixed->size(), 0);
  ASSERT_TRUE(mixed->is_empty());

  // Module holders.
  mixed->push_back(Linear(3, 4));
  ASSERT_EQ(mixed->size(), 1);
  ASSERT_EQ(mixed->named_children()[0].key(), "0");
  mixed->push_back(std::string("linear2"), Linear(3, 4));
  ASSERT_EQ(mixed->size(), 2);
  ASSERT_EQ(mixed->named_children()[1].key(), "linear2");

  // Shared pointers.
  mixed->push_back("shared_m1", std::make_shared<M>(1));
  ASSERT_EQ(mixed->size(), 3);
  ASSERT_EQ(mixed->named_children()[2].key(), "shared_m1");
  mixed->push_back(std::make_shared<M>(1));
  ASSERT_EQ(mixed->size(), 4);
  ASSERT_EQ(mixed->named_children()[3].key(), "3");

  // Module values.
  mixed->push_back(M(1));
  ASSERT_EQ(mixed->size(), 5);
  ASSERT_EQ(mixed->named_children()[4].key(), "4");
  mixed->push_back(std::string("m2"), M(1));
  ASSERT_EQ(mixed->size(), 6);
  ASSERT_EQ(mixed->named_children()[5].key(), "m2");

  // AnyModule, unnamed and then named.
  Sequential any_holder;
  auto any_linear = torch::nn::AnyModule(torch::nn::Linear(1, 2));
  ASSERT_EQ(any_holder->size(), 0);
  ASSERT_TRUE(any_holder->is_empty());
  any_holder->push_back(any_linear);
  ASSERT_EQ(any_holder->size(), 1);
  ASSERT_EQ(any_holder->named_children()[0].key(), "0");
  any_holder->push_back("fc", any_linear);
  ASSERT_EQ(any_holder->size(), 2);
  ASSERT_EQ(any_holder->named_children()[1].key(), "fc");
}
|
|
|
|
// at<T>(i) returns a reference to the very module pushed at position i, and
// rejects every index >= size() with "Index out of range".
TEST_F(SequentialTest, AccessWithAt) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (const auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index
  for (const auto i : c10::irange(modules.size())) {
    ASSERT_EQ(&sequential->at<M>(i), modules[i].get());
  }

  // throws for a bad index, starting with the first out-of-range one
  ASSERT_THROWS_WITH(sequential->at<M>(modules.size()), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1000000), "Index out of range");
}
|
|
|
|
// ptr(i), operator[](i) and ptr<T>(i) all return the very module pushed at
// position i, and reject every index >= size() with "Index out of range".
TEST_F(SequentialTest, AccessWithPtr) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (const auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index
  for (const auto i : c10::irange(modules.size())) {
    ASSERT_EQ(sequential->ptr(i).get(), modules[i].get());
    ASSERT_EQ(sequential[i].get(), modules[i].get());
    ASSERT_EQ(sequential->ptr<M>(i).get(), modules[i].get());
  }

  // throws for a bad index, starting with the first out-of-range one
  ASSERT_THROWS_WITH(sequential->ptr(modules.size()), "Index out of range");
  ASSERT_THROWS_WITH(sequential[modules.size()], "Index out of range");
  ASSERT_THROWS_WITH(sequential->ptr<M>(modules.size()), "Index out of range");
  ASSERT_THROWS_WITH(sequential->ptr(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->ptr(modules.size() + 1000000), "Index out of range");
}
|
|
|
|
// forward() on a Sequential that holds no modules is rejected with an error.
TEST_F(SequentialTest, CallingForwardOnEmptySequentialIsDisallowed) {
  Sequential nothing_inside;
  ASSERT_THROWS_WITH(
      nothing_inside->forward<int>(),
      "Cannot call forward() on an empty Sequential");
}
|
|
|
|
// Sequential::forward feeds each module's output into the next one. Each mock
// checks that it receives its expected input (1, 2, 3 in order) and adds one,
// so the chain turns 1 into 4.
TEST_F(SequentialTest, CallingForwardChainsCorrectly) {
  struct MockModule : torch::nn::Module {
    explicit MockModule(int value) : expected(value) {}
    int expected;
    int forward(int value) {
      // A gtest check instead of `assert`: `assert` is compiled out under
      // NDEBUG, which would silently drop this check in release builds.
      // ASSERT_* cannot be used here because forward() returns a value.
      EXPECT_EQ(value, expected);
      return value + 1;
    }
  };

  Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});

  ASSERT_EQ(sequential->forward<int>(1), 4);
}
|
|
|
|
// Requesting a return type other than the one the last module produces
// raises an error naming both types.
TEST_F(SequentialTest, CallingForwardWithTheWrongReturnTypeThrows) {
  struct ReturnsFive : public torch::nn::Module {
    int forward() {
      return 5;
    }
  };

  Sequential single(ReturnsFive{});
  // The matching type works...
  ASSERT_EQ(single->forward<int>(), 5);
  // ...while a mismatched one throws.
  ASSERT_THROWS_WITH(
      single->forward<float>(),
      "The type of the return value is int, but you asked for type float");
}
|
|
|
|
// forward() without an explicit template argument returns a torch::Tensor.
TEST_F(SequentialTest, TheReturnTypeOfForwardDefaultsToTensor) {
  struct Passthrough : public torch::nn::Module {
    torch::Tensor forward(torch::Tensor v) {
      return v;
    }
  };

  Sequential single(Passthrough{});
  auto input = torch::ones({3, 3}, torch::requires_grad());
  const bool round_trips = single->forward(input).equal(input);
  ASSERT_TRUE(round_trips);
}
|
|
|
|
// The result of forward() is the output of the final module: three Linear
// layers map 10 -> 3 -> 5 -> 100 features, so the result is [1000, 100].
TEST_F(SequentialTest, ForwardReturnsTheLastValue) {
  torch::manual_seed(0);
  Sequential stack(Linear(10, 3), Linear(3, 5), Linear(5, 100));

  auto input = torch::randn({1000, 10}, torch::requires_grad());
  auto output = stack->forward(input);
  ASSERT_EQ(output.ndimension(), 2);
  ASSERT_EQ(output.size(0), 1000);
  ASSERT_EQ(output.size(1), 100);
}
|
|
|
|
// Smoke test: a range of built-in modules can all be stored in one Sequential.
TEST_F(SequentialTest, SanityCheckForHoldingStandardModules) {
  Sequential standard_modules(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm2d(5),
      Embedding(4, 10),
      LSTM(4, 5));
}
|
|
|
|
// extend() appends, in order, the modules of another Sequential (leaving that
// Sequential unchanged) or of a vector of module pointers.
TEST_F(SequentialTest, ExtendPushesModulesFromOtherSequential) {
  struct A : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct B : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct C : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct D : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };

  // Extending from another Sequential.
  Sequential front(A{}, B{});
  Sequential back(C{}, D{});
  front->extend(*back);

  ASSERT_EQ(front->size(), 4);
  ASSERT_TRUE(front[0]->as<A>());
  ASSERT_TRUE(front[1]->as<B>());
  ASSERT_TRUE(front[2]->as<C>());
  ASSERT_TRUE(front[3]->as<D>());

  // The source Sequential is left as it was.
  ASSERT_EQ(back->size(), 2);
  ASSERT_TRUE(back[0]->as<C>());
  ASSERT_TRUE(back[1]->as<D>());

  // Extending from a vector of shared pointers.
  std::vector<std::shared_ptr<A>> extra = {
      std::make_shared<A>(), std::make_shared<A>()};
  back->extend(extra);

  ASSERT_EQ(back->size(), 4);
  ASSERT_TRUE(back[0]->as<C>());
  ASSERT_TRUE(back[1]->as<D>());
  ASSERT_TRUE(back[2]->as<A>());
  ASSERT_TRUE(back[3]->as<A>());
}
|
|
|
|
// Copying a Sequential holder shares the underlying implementation: both
// holders point at the same impl and iterate the very same AnyModule objects.
TEST_F(SequentialTest, HasReferenceSemantics) {
  Sequential original(Linear(2, 3), Linear(4, 4), Linear(4, 5));
  Sequential alias(original);

  ASSERT_EQ(original.get(), alias.get());
  ASSERT_EQ(original->size(), alias->size());

  // Compare element addresses, not values.
  const auto same_object = [](const AnyModule& lhs, const AnyModule& rhs) {
    return &lhs == &rhs;
  };
  ASSERT_TRUE(std::equal(
      original->begin(), original->end(), alias->begin(), same_object));
}
|
|
|
|
// clone() produces a deep copy: same module kinds in the same order, distinct
// module objects, and parameters that start out equal but no longer track
// the originals once the originals are modified.
TEST_F(SequentialTest, IsCloneable) {
  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
  Sequential clone =
      std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
  ASSERT_EQ(sequential->size(), clone->size());

  for (const auto i : c10::irange(sequential->size())) {
    // The modules should be the same kind (type).
    ASSERT_EQ(sequential[i]->name(), clone[i]->name());
    // But not pointer-equal (distinct objects).
    ASSERT_NE(sequential[i], clone[i]);
  }

  // Verify that the clone is deep, i.e. parameters of modules are cloned too.

  torch::NoGradGuard no_grad;

  auto params1 = sequential->named_parameters();
  auto params2 = clone->named_parameters();
  ASSERT_EQ(params1.size(), params2.size());
  for (auto& param : params1) {
    // Look the counterpart up once instead of once per assertion.
    const auto& cloned = params2[param.key()];
    ASSERT_FALSE(pointer_equal(param.value(), cloned));
    ASSERT_EQ(param->device(), cloned.device());
    ASSERT_TRUE(param->allclose(cloned));
    param->add_(2);
  }
  // After shifting only the originals, no pair may still match.
  for (auto& param : params1) {
    ASSERT_FALSE(param->allclose(params2[param.key()]));
  }
}
|
|
|
|
// Every element of a Sequential is registered as a child module, in order.
TEST_F(SequentialTest, RegistersElementsAsSubmodules) {
  Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), Dropout2d(0.5));

  auto modules = sequential->children();
  // Guard the indexing below: a short vector would otherwise be read out of
  // bounds instead of failing cleanly.
  ASSERT_EQ(modules.size(), 3);
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  ASSERT_TRUE(modules[2]->as<Dropout2d>());
}
|
|
|
|
// Cloning onto CUDA device 0 places every parameter and buffer of the clone
// on that device.
TEST_F(SequentialTest, CloneToDevice_CUDA) {
  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
  const torch::Device target(torch::kCUDA, 0);
  Sequential copy =
      std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(target));
  for (const auto& parameter : copy->parameters()) {
    ASSERT_EQ(parameter.device(), target);
  }
  for (const auto& buffer : copy->buffers()) {
    ASSERT_EQ(buffer.device(), target);
  }
}
|
|
|
|
// Printing a Sequential lists one child per line, indented by two spaces and
// labelled with its key: its position for unnamed children, or its given
// name for named ones.
TEST_F(SequentialTest, PrettyPrintSequential) {
  Sequential sequential(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm2d(5),
      Embedding(4, 10),
      LSTM(4, 5));
  ASSERT_EQ(
      c10::str(sequential),
      "torch::nn::Sequential(\n"
      "  (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")");

  Sequential sequential_named(
      {{"linear", Linear(10, 3)},
       {"conv2d", Conv2d(1, 2, 3)},
       {"dropout", Dropout(0.5)},
       {"batchnorm2d", BatchNorm2d(5)},
       {"embedding", Embedding(4, 10)},
       {"lstm", LSTM(4, 5)}});
  ASSERT_EQ(
      c10::str(sequential_named),
      "torch::nn::Sequential(\n"
      "  (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (conv2d): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")");
}
|
|
|
|
// Runs Sequential::forward over a range of standard modules: transposed
// convolutions, EmbeddingBag, MultiheadAttention, MaxUnpool and recurrent
// layers/cells. Going by the test name, several of these modules' forward()
// take optional arguments that are omitted here.
// NOTE(review): confirm this against each module's declaration.
// Each case lives in its own scope. Cases whose modules are randomly
// initialized call torch::manual_seed(0) just before building them, so the
// hard-coded expected values depend on that exact ordering.
TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
  // ConvTranspose1d: bias is disabled and the weight of element [1] (cast to
  // ConvTranspose1dImpl) is overwritten with arange values, so the output is
  // deterministic.
  {
    Sequential sequential(
        Identity(),
        ConvTranspose1d(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose1dImpl>(sequential[1])
        ->weight.set_data(torch::arange(18.).reshape({3, 2, 3}));
    auto x = torch::arange(30.).reshape({2, 3, 5});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{150., 333., 552., 615., 678., 501., 276.},
          {195., 432., 714., 804., 894., 654., 357.}},
         {{420., 918., 1497., 1560., 1623., 1176., 636.},
          {600., 1287., 2064., 2154., 2244., 1599., 852.}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  // ConvTranspose2d: same deterministic setup (no bias, arange weight).
  {
    Sequential sequential(
        Identity(),
        ConvTranspose2d(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose2dImpl>(sequential[1])
        ->weight.set_data(torch::arange(54.).reshape({3, 2, 3, 3}));
    auto x = torch::arange(75.).reshape({1, 3, 5, 5});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{{2250., 4629., 7140., 7311., 7482., 5133., 2640.},
           {4995., 10272., 15837., 16206., 16575., 11364., 5841.},
           {8280., 17019., 26226., 26820., 27414., 18783., 9648.},
           {9225., 18954., 29196., 29790., 30384., 20808., 10683.},
           {10170., 20889., 32166., 32760., 33354., 22833., 11718.},
           {7515., 15420., 23721., 24144., 24567., 16800., 8613.},
           {4140., 8487., 13044., 13269., 13494., 9219., 4722.}},
          {{2925., 6006., 9246., 9498., 9750., 6672., 3423.},
           {6480., 13296., 20454., 20985., 21516., 14712., 7542.},
           {10710., 21960., 33759., 34596., 35433., 24210., 12402.},
           {12060., 24705., 37944., 38781., 39618., 27045., 13842.},
           {13410., 27450., 42129., 42966., 43803., 29880., 15282.},
           {9810., 20064., 30768., 31353., 31938., 21768., 11124.},
           {5355., 10944., 16770., 17076., 17382., 11838., 6045.}}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  // ConvTranspose3d: same deterministic setup (no bias, arange weight).
  {
    Sequential sequential(
        Identity(),
        ConvTranspose3d(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose3dImpl>(sequential[1])
        ->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2}));
    auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
           {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
           {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
          {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
           {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
           {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  // EmbeddingBag built from a fixed 2x3 weight and looked up with indices
  // {1, 0}. The expected values are the element-wise mean of the two weight
  // rows, e.g. (1 + 4) / 2 = 2.5.
  {
    auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
    Sequential sequential(Identity(), EmbeddingBag::from_pretrained(weight));
    auto x = torch::tensor({{1, 0}}, torch::kLong);
    auto y = sequential->forward(x);
    auto expected = torch::tensor({2.5000, 3.7000, 4.6500});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  // MultiheadAttention with three inputs. They are built as
  // (batch_size, len, embed_dim) and passed through transpose(0, 1), which
  // swaps the first two dims. forward() is asked for a two-tensor tuple.
  // Expected values depend on manual_seed(0) and use loose tolerances.
  {
    torch::manual_seed(0);

    int64_t embed_dim = 8;
    int64_t num_heads = 4;
    int64_t batch_size = 8;
    int64_t src_len = 3;
    int64_t tgt_len = 1;

    auto query = torch::ones({batch_size, tgt_len, embed_dim});
    auto key = torch::ones({batch_size, src_len, embed_dim});
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto value = key;

    Sequential sequential(MultiheadAttention(embed_dim, num_heads));
    auto output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(
        query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1));

    auto attn_output = std::get<0>(output);
    auto attn_output_expected = torch::tensor(
        {{{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}}});
    ASSERT_TRUE(
        torch::allclose(attn_output, attn_output_expected, 1e-05, 2e-04));

    // Every key is identical (all ones), which presumably explains the uniform
    // 1/3 weights over the src_len = 3 positions.
    auto attn_output_weights = std::get<1>(output);
    auto attn_output_weights_expected = torch::tensor(
        {{{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}}});
    ASSERT_TRUE(torch::allclose(
        attn_output_weights, attn_output_weights_expected, 1e-05, 2e-04));
  }
  // MaxUnpool1d: values and indices are passed as two forward() arguments.
  {
    auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    auto x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat));
    Sequential sequential(MaxUnpool1d(3));
    auto y = sequential->forward(x, indices);
    auto expected =
        torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  // MaxUnpool2d with stride 2 and padding 1; values and indices as above.
  {
    auto indices = torch::tensor(
        {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
         {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
        torch::kLong);
    auto x = torch::tensor(
        {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
         {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
        torch::dtype(torch::kFloat));
    Sequential sequential(
        MaxUnpool2d(MaxUnpool2dOptions(3).stride(2).padding(1)));
    auto y = sequential->forward(x, indices);
    auto expected = torch::tensor(
        {{{{0, 0, 0, 0, 0},
           {0, 6, 0, 8, 9},
           {0, 0, 0, 0, 0},
           {0, 16, 0, 18, 19},
           {0, 21, 0, 23, 24}}},
         {{{0, 0, 0, 0, 0},
           {0, 31, 0, 33, 34},
           {0, 0, 0, 0, 0},
           {0, 41, 0, 43, 44},
           {0, 46, 0, 48, 49}}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  // MaxUnpool3d: a single value placed at flat index 26 of a 3x3x3 output.
  {
    auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
    auto x = torch::tensor(
        {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
    Sequential sequential(MaxUnpool3d(3));
    auto y = sequential->forward(x, indices);
    auto expected = torch::tensor(
        {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  // RNN after Identity, randomly initialized under manual_seed(0). forward()
  // is asked for a two-tensor tuple; only element 0 is checked.
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), RNN(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output =
        sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
    auto expected_output = torch::tensor(
        {{{-0.0645, -0.7274, 0.4531},
          {-0.0645, -0.7274, 0.4531},
          {-0.0645, -0.7274, 0.4531}},
         {{-0.3970, -0.6950, 0.6009},
          {-0.3970, -0.6950, 0.6009},
          {-0.3970, -0.6950, 0.6009}}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
  // LSTM: same pattern, but the requested result type nests a second tuple.
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), LSTM(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output = sequential->forward<
        std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>>(x);
    auto expected_output = torch::tensor(
        {{{-0.2693, -0.1240, 0.0744},
          {-0.2693, -0.1240, 0.0744},
          {-0.2693, -0.1240, 0.0744}},
         {{-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183}}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
  // GRU: same pattern as RNN.
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), GRU(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output =
        sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
    auto expected_output = torch::tensor(
        {{{-0.1134, 0.0467, 0.2336},
          {-0.1134, 0.0467, 0.2336},
          {-0.1134, 0.0467, 0.2336}},
         {{-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960}}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
  // Cell variants below: with the same seed and sizes, each expected row
  // equals the first time step of the corresponding layer case above.
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), RNNCell(2, 3));
    auto x = torch::ones({2, 2});
    auto rnn_output = sequential->forward<torch::Tensor>(x);
    auto expected_output =
        torch::tensor({{-0.0645, -0.7274, 0.4531}, {-0.0645, -0.7274, 0.4531}});
    ASSERT_TRUE(torch::allclose(rnn_output, expected_output, 1e-05, 2e-04));
  }
  // LSTMCell returns a two-tensor tuple; only element 0 is checked.
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), LSTMCell(2, 3));
    auto x = torch::ones({2, 2});
    auto rnn_output =
        sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
    auto expected_output =
        torch::tensor({{-0.2693, -0.1240, 0.0744}, {-0.2693, -0.1240, 0.0744}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), GRUCell(2, 3));
    auto x = torch::ones({2, 2});
    auto rnn_output = sequential->forward<torch::Tensor>(x);
    auto expected_output =
        torch::tensor({{-0.1134, 0.0467, 0.2336}, {-0.1134, 0.0467, 0.2336}});
    ASSERT_TRUE(torch::allclose(rnn_output, expected_output, 1e-05, 2e-04));
  }
}
|