mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-24 07:27:32 +08:00
Rewrite C++ API tests in gtest (#11953)
Summary: This PR is a large codemod to rewrite all C++ API tests with GoogleTest (gtest) instead of Catch. You can largely trust me to have correctly code-modded the tests, so it's not required to review every of the 2000+ changed lines. However, additional things I changed were: 1. Moved the cmake parts for these tests into their own `CMakeLists.txt` under `test/cpp/api` and calling `add_subdirectory` from `torch/CMakeLists.txt` 2. Fixing DataParallel tests which weren't being compiled because `USE_CUDA` wasn't correctly being set at all. 3. Updated README ezyang ebetica Pull Request resolved: https://github.com/pytorch/pytorch/pull/11953 Differential Revision: D9998883 Pulled By: goldsborough fbshipit-source-id: affe3f320b0ca63e7e0019926a59076bb943db80
This commit is contained in:
committed by
Facebook Github Bot
parent
d0db23e95a
commit
825181ea9d
@ -1,4 +1,4 @@
|
||||
#include "catch_utils.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/nn/modules/batchnorm.h>
|
||||
@ -10,9 +10,7 @@
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <test/cpp/api/util.h>
|
||||
|
||||
using Catch::StartsWith;
|
||||
#include <test/cpp/api/support.h>
|
||||
|
||||
using namespace torch::nn;
|
||||
using namespace torch::test;
|
||||
@ -39,303 +37,292 @@ class NestedModel : public torch::nn::Module {
|
||||
std::shared_ptr<TestModel> t;
|
||||
};
|
||||
|
||||
CATCH_TEST_CASE("modules") {
|
||||
torch::manual_seed(0);
|
||||
CATCH_SECTION("conv") {
|
||||
CATCH_SECTION("1d") {
|
||||
Conv1d model(Conv1dOptions(3, 2, 3).stride(2));
|
||||
auto x = torch::randn({2, 3, 5}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
struct ModulesTest : torch::test::SeedingFixture {};
|
||||
|
||||
s.backward();
|
||||
CATCH_REQUIRE(y.ndimension() == 3);
|
||||
CATCH_REQUIRE(s.ndimension() == 0);
|
||||
for (auto i = 0; i < 3; i++) {
|
||||
CATCH_REQUIRE(y.size(i) == 2);
|
||||
}
|
||||
TEST_F(ModulesTest, Conv1d) {
|
||||
Conv1d model(Conv1dOptions(3, 2, 3).stride(2));
|
||||
auto x = torch::randn({2, 3, 5}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3);
|
||||
}
|
||||
CATCH_SECTION("2d") {
|
||||
CATCH_SECTION("even") {
|
||||
Conv2d model(Conv2dOptions(3, 2, 3).stride(2));
|
||||
auto x = torch::randn({2, 3, 5, 5}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
CATCH_REQUIRE(y.ndimension() == 4);
|
||||
CATCH_REQUIRE(s.ndimension() == 0);
|
||||
for (auto i = 0; i < 4; i++) {
|
||||
CATCH_REQUIRE(y.size(i) == 2);
|
||||
}
|
||||
|
||||
CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3);
|
||||
}
|
||||
|
||||
CATCH_SECTION("uneven") {
|
||||
Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({2, 2}));
|
||||
auto x = torch::randn({2, 3, 5, 4}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
CATCH_REQUIRE(y.ndimension() == 4);
|
||||
CATCH_REQUIRE(s.ndimension() == 0);
|
||||
for (auto i = 0; i < 4; i++) {
|
||||
CATCH_REQUIRE(y.size(i) == 2);
|
||||
}
|
||||
|
||||
CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 2);
|
||||
}
|
||||
}
|
||||
CATCH_SECTION("3d") {
|
||||
Conv3d model(Conv3dOptions(3, 2, 3).stride(2));
|
||||
auto x = torch::randn({2, 3, 5, 5, 5}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
CATCH_REQUIRE(y.ndimension() == 5);
|
||||
CATCH_REQUIRE(s.ndimension() == 0);
|
||||
for (auto i = 0; i < 5; i++) {
|
||||
CATCH_REQUIRE(y.size(i) == 2);
|
||||
}
|
||||
|
||||
CATCH_REQUIRE(
|
||||
model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3 * 3);
|
||||
}
|
||||
}
|
||||
CATCH_SECTION("linear") {
|
||||
CATCH_SECTION("basic1") {
|
||||
Linear model(5, 2);
|
||||
auto x = torch::randn({10, 5}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
CATCH_REQUIRE(y.ndimension() == 2);
|
||||
CATCH_REQUIRE(s.ndimension() == 0);
|
||||
CATCH_REQUIRE(y.size(0) == 10);
|
||||
CATCH_REQUIRE(y.size(1) == 2);
|
||||
|
||||
CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
|
||||
}
|
||||
s.backward();
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
for (auto i = 0; i < 3; i++) {
|
||||
ASSERT_EQ(y.size(i), 2);
|
||||
}
|
||||
|
||||
CATCH_SECTION("simple") {
|
||||
auto model = std::make_shared<SimpleContainer>();
|
||||
auto l1 = model->add(Linear(10, 3), "l1");
|
||||
auto l2 = model->add(Linear(3, 5), "l2");
|
||||
auto l3 = model->add(Linear(5, 100), "l3");
|
||||
|
||||
auto x = torch::randn({1000, 10}, torch::requires_grad());
|
||||
x = l1->forward(x).clamp_min(0);
|
||||
x = l2->forward(x).clamp_min(0);
|
||||
x = l3->forward(x).clamp_min(0);
|
||||
|
||||
x.backward();
|
||||
CATCH_REQUIRE(x.ndimension() == 2);
|
||||
CATCH_REQUIRE(x.size(0) == 1000);
|
||||
CATCH_REQUIRE(x.size(1) == 100);
|
||||
CATCH_REQUIRE(x.min().toCFloat() == 0);
|
||||
}
|
||||
|
||||
CATCH_SECTION("embedding") {
|
||||
CATCH_SECTION("basic") {
|
||||
const int64_t dict_size = 10;
|
||||
Embedding model(dict_size, 2);
|
||||
CATCH_REQUIRE(model->parameters().contains("weight"));
|
||||
CATCH_REQUIRE(model->weight.ndimension() == 2);
|
||||
CATCH_REQUIRE(model->weight.size(0) == dict_size);
|
||||
CATCH_REQUIRE(model->weight.size(1) == 2);
|
||||
|
||||
// Cannot get gradients to change indices (input) - only for embedding
|
||||
// params
|
||||
auto x = torch::full({10}, dict_size - 1, torch::kInt64);
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
CATCH_REQUIRE(y.ndimension() == 2);
|
||||
CATCH_REQUIRE(s.ndimension() == 0);
|
||||
CATCH_REQUIRE(y.size(0) == 10);
|
||||
CATCH_REQUIRE(y.size(1) == 2);
|
||||
|
||||
CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * dict_size);
|
||||
}
|
||||
|
||||
CATCH_SECTION("list") {
|
||||
Embedding model(6, 4);
|
||||
auto x = torch::full({2, 3}, 5, torch::kInt64);
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
CATCH_REQUIRE(y.ndimension() == 3);
|
||||
CATCH_REQUIRE(y.size(0) == 2);
|
||||
CATCH_REQUIRE(y.size(1) == 3);
|
||||
CATCH_REQUIRE(y.size(2) == 4);
|
||||
}
|
||||
}
|
||||
|
||||
CATCH_SECTION("dropout") {
|
||||
Dropout dropout(0.5);
|
||||
torch::Tensor x = torch::ones(100, torch::requires_grad());
|
||||
torch::Tensor y = dropout->forward(x);
|
||||
|
||||
y.backward();
|
||||
CATCH_REQUIRE(y.ndimension() == 1);
|
||||
CATCH_REQUIRE(y.size(0) == 100);
|
||||
CATCH_REQUIRE(y.sum().toCFloat() < 130); // Probably
|
||||
CATCH_REQUIRE(y.sum().toCFloat() > 70); // Probably
|
||||
|
||||
dropout->eval();
|
||||
y = dropout->forward(x);
|
||||
CATCH_REQUIRE(y.sum().toCFloat() == 100);
|
||||
}
|
||||
|
||||
CATCH_SECTION("param") {
|
||||
auto model = std::make_shared<NestedModel>();
|
||||
auto parameters = model->parameters();
|
||||
CATCH_REQUIRE(parameters["param"].size(0) == 3);
|
||||
CATCH_REQUIRE(parameters["param"].size(1) == 2);
|
||||
CATCH_REQUIRE(parameters["param"].size(2) == 21);
|
||||
CATCH_REQUIRE(parameters["l1.bias"].size(0) == 20);
|
||||
CATCH_REQUIRE(parameters["l1.weight"].size(0) == 20);
|
||||
CATCH_REQUIRE(parameters["l1.weight"].size(1) == 5);
|
||||
CATCH_REQUIRE(parameters["test.l1.bias"].size(0) == 3);
|
||||
CATCH_REQUIRE(parameters["test.l1.weight"].size(0) == 3);
|
||||
CATCH_REQUIRE(parameters["test.l1.weight"].size(1) == 10);
|
||||
CATCH_REQUIRE(parameters["test.l2.bias"].size(0) == 5);
|
||||
CATCH_REQUIRE(parameters["test.l2.weight"].size(0) == 5);
|
||||
CATCH_REQUIRE(parameters["test.l2.weight"].size(1) == 3);
|
||||
CATCH_REQUIRE(parameters["test.l3.bias"].size(0) == 100);
|
||||
CATCH_REQUIRE(parameters["test.l3.weight"].size(0) == 100);
|
||||
CATCH_REQUIRE(parameters["test.l3.weight"].size(1) == 5);
|
||||
}
|
||||
|
||||
CATCH_SECTION("functional") {
|
||||
{
|
||||
bool was_called = false;
|
||||
auto functional = Functional([&was_called](torch::Tensor input) {
|
||||
was_called = true;
|
||||
return input;
|
||||
});
|
||||
auto output = functional->forward(torch::ones(5, torch::requires_grad()));
|
||||
CATCH_REQUIRE(was_called);
|
||||
CATCH_REQUIRE(output.equal(torch::ones(5, torch::requires_grad())));
|
||||
|
||||
was_called = false;
|
||||
// Use the call operator overload here.
|
||||
output = functional(torch::ones(5, torch::requires_grad()));
|
||||
CATCH_REQUIRE(was_called);
|
||||
CATCH_REQUIRE(output.equal(torch::ones(5, torch::requires_grad())));
|
||||
}
|
||||
{
|
||||
auto functional = Functional(torch::relu);
|
||||
CATCH_REQUIRE(functional(torch::ones({})).toCFloat() == 1);
|
||||
CATCH_REQUIRE(functional(torch::ones({})).toCFloat() == 1);
|
||||
CATCH_REQUIRE(functional(torch::ones({}) * -1).toCFloat() == 0);
|
||||
}
|
||||
{
|
||||
auto functional =
|
||||
Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1);
|
||||
CATCH_REQUIRE(functional(torch::ones({})).toCFloat() == 0);
|
||||
}
|
||||
}
|
||||
|
||||
CATCH_SECTION("batchnorm") {
|
||||
{
|
||||
BatchNorm bn(5);
|
||||
|
||||
// Is stateful by default.
|
||||
CATCH_REQUIRE(bn->options.stateful());
|
||||
|
||||
CATCH_REQUIRE(bn->running_mean.defined());
|
||||
CATCH_REQUIRE(bn->running_mean.dim() == 1);
|
||||
CATCH_REQUIRE(bn->running_mean.size(0) == 5);
|
||||
|
||||
CATCH_REQUIRE(bn->running_variance.defined());
|
||||
CATCH_REQUIRE(bn->running_variance.dim() == 1);
|
||||
CATCH_REQUIRE(bn->running_variance.size(0) == 5);
|
||||
|
||||
// Is affine by default.
|
||||
CATCH_REQUIRE(bn->options.affine());
|
||||
|
||||
CATCH_REQUIRE(bn->weight.defined());
|
||||
CATCH_REQUIRE(bn->weight.dim() == 1);
|
||||
CATCH_REQUIRE(bn->weight.size(0) == 5);
|
||||
|
||||
CATCH_REQUIRE(bn->bias.defined());
|
||||
CATCH_REQUIRE(bn->bias.dim() == 1);
|
||||
CATCH_REQUIRE(bn->bias.size(0) == 5);
|
||||
}
|
||||
{
|
||||
BatchNorm bn(BatchNormOptions(5).stateful(false).affine(false));
|
||||
|
||||
CATCH_REQUIRE(!bn->running_mean.defined());
|
||||
CATCH_REQUIRE(!bn->running_variance.defined());
|
||||
CATCH_REQUIRE(!bn->weight.defined());
|
||||
CATCH_REQUIRE(!bn->bias.defined());
|
||||
|
||||
CATCH_REQUIRE_THROWS_WITH(
|
||||
bn->forward(torch::ones({2, 5})),
|
||||
StartsWith("Calling BatchNorm::forward is only permitted "
|
||||
"when the 'stateful' option is true (was false). "
|
||||
"Use BatchNorm::pure_forward instead."));
|
||||
}
|
||||
{
|
||||
BatchNorm bn(BatchNormOptions(5).affine(false));
|
||||
bn->eval();
|
||||
|
||||
// Want to make sure we use the supplied values in `pure_forward` even if
|
||||
// we are stateful.
|
||||
auto input = torch::randn({2, 5});
|
||||
auto mean = torch::randn(5);
|
||||
auto variance = torch::rand(5);
|
||||
auto output = bn->pure_forward(input, mean, variance);
|
||||
auto expected =
|
||||
(input - mean) / torch::sqrt(variance + bn->options.eps());
|
||||
CATCH_REQUIRE(output.allclose(expected));
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(model->parameters()["weight"].grad().numel(), 3 * 2 * 3);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("modules_cuda", "[cuda]") {
|
||||
torch::manual_seed(0);
|
||||
CATCH_SECTION("1") {
|
||||
Linear model(5, 2);
|
||||
model->to(torch::kCUDA);
|
||||
auto x =
|
||||
torch::randn({10, 5}, torch::device(torch::kCUDA).requires_grad(true));
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
TEST_F(ModulesTest, Conv2dEven) {
|
||||
Conv2d model(Conv2dOptions(3, 2, 3).stride(2));
|
||||
auto x = torch::randn({2, 3, 5, 5}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
CATCH_REQUIRE(y.ndimension() == 2);
|
||||
CATCH_REQUIRE(s.ndimension() == 0);
|
||||
CATCH_REQUIRE(y.size(0) == 10);
|
||||
CATCH_REQUIRE(y.size(1) == 2);
|
||||
|
||||
CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
|
||||
s.backward();
|
||||
ASSERT_EQ(y.ndimension(), 4);
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
for (auto i = 0; i < 4; i++) {
|
||||
ASSERT_EQ(y.size(i), 2);
|
||||
}
|
||||
|
||||
CATCH_SECTION("2") {
|
||||
Linear model(5, 2);
|
||||
model->to(torch::kCUDA);
|
||||
model->to(torch::kCPU);
|
||||
auto x = torch::randn({10, 5}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
CATCH_REQUIRE(y.ndimension() == 2);
|
||||
CATCH_REQUIRE(s.ndimension() == 0);
|
||||
CATCH_REQUIRE(y.size(0) == 10);
|
||||
CATCH_REQUIRE(y.size(1) == 2);
|
||||
|
||||
CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
|
||||
}
|
||||
ASSERT_EQ(model->parameters()["weight"].grad().numel(), 3 * 2 * 3 * 3);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Conv2dUneven) {
|
||||
Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({2, 2}));
|
||||
auto x = torch::randn({2, 3, 5, 4}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(y.ndimension(), 4);
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
for (auto i = 0; i < 4; i++) {
|
||||
ASSERT_EQ(y.size(i), 2);
|
||||
}
|
||||
|
||||
ASSERT_EQ(model->parameters()["weight"].grad().numel(), 3 * 2 * 3 * 2);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Conv3d) {
|
||||
Conv3d model(Conv3dOptions(3, 2, 3).stride(2));
|
||||
auto x = torch::randn({2, 3, 5, 5, 5}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(y.ndimension(), 5);
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
for (auto i = 0; i < 5; i++) {
|
||||
ASSERT_EQ(y.size(i), 2);
|
||||
}
|
||||
|
||||
ASSERT_TRUE(
|
||||
model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3 * 3);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Linear) {
|
||||
Linear model(5, 2);
|
||||
auto x = torch::randn({10, 5}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(y.ndimension(), 2);
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.size(0), 10);
|
||||
ASSERT_EQ(y.size(1), 2);
|
||||
|
||||
ASSERT_EQ(model->parameters()["weight"].grad().numel(), 2 * 5);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, SimpleContainer) {
|
||||
auto model = std::make_shared<SimpleContainer>();
|
||||
auto l1 = model->add(Linear(10, 3), "l1");
|
||||
auto l2 = model->add(Linear(3, 5), "l2");
|
||||
auto l3 = model->add(Linear(5, 100), "l3");
|
||||
|
||||
auto x = torch::randn({1000, 10}, torch::requires_grad());
|
||||
x = l1->forward(x).clamp_min(0);
|
||||
x = l2->forward(x).clamp_min(0);
|
||||
x = l3->forward(x).clamp_min(0);
|
||||
|
||||
x.backward();
|
||||
ASSERT_EQ(x.ndimension(), 2);
|
||||
ASSERT_EQ(x.size(0), 1000);
|
||||
ASSERT_EQ(x.size(1), 100);
|
||||
ASSERT_EQ(x.min().toCFloat(), 0);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, EmbeddingBasic) {
|
||||
const int64_t dict_size = 10;
|
||||
Embedding model(dict_size, 2);
|
||||
ASSERT_TRUE(model->parameters().contains("weight"));
|
||||
ASSERT_EQ(model->weight.ndimension(), 2);
|
||||
ASSERT_EQ(model->weight.size(0), dict_size);
|
||||
ASSERT_EQ(model->weight.size(1), 2);
|
||||
|
||||
// Cannot get gradients to change indices (input) - only for embedding
|
||||
// params
|
||||
auto x = torch::full({10}, dict_size - 1, torch::kInt64);
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(y.ndimension(), 2);
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.size(0), 10);
|
||||
ASSERT_EQ(y.size(1), 2);
|
||||
|
||||
ASSERT_EQ(model->parameters()["weight"].grad().numel(), 2 * dict_size);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, EmbeddingList) {
|
||||
Embedding model(6, 4);
|
||||
auto x = torch::full({2, 3}, 5, torch::kInt64);
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_EQ(y.size(0), 2);
|
||||
ASSERT_EQ(y.size(1), 3);
|
||||
ASSERT_EQ(y.size(2), 4);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Dropout) {
|
||||
Dropout dropout(0.5);
|
||||
torch::Tensor x = torch::ones(100, torch::requires_grad());
|
||||
torch::Tensor y = dropout->forward(x);
|
||||
|
||||
y.backward();
|
||||
ASSERT_EQ(y.ndimension(), 1);
|
||||
ASSERT_EQ(y.size(0), 100);
|
||||
ASSERT_LT(y.sum().toCFloat(), 130); // Probably
|
||||
ASSERT_GT(y.sum().toCFloat(), 70); // Probably
|
||||
|
||||
dropout->eval();
|
||||
y = dropout->forward(x);
|
||||
ASSERT_EQ(y.sum().toCFloat(), 100);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Parameters) {
|
||||
auto model = std::make_shared<NestedModel>();
|
||||
auto parameters = model->parameters();
|
||||
ASSERT_EQ(parameters["param"].size(0), 3);
|
||||
ASSERT_EQ(parameters["param"].size(1), 2);
|
||||
ASSERT_EQ(parameters["param"].size(2), 21);
|
||||
ASSERT_EQ(parameters["l1.bias"].size(0), 20);
|
||||
ASSERT_EQ(parameters["l1.weight"].size(0), 20);
|
||||
ASSERT_EQ(parameters["l1.weight"].size(1), 5);
|
||||
ASSERT_EQ(parameters["test.l1.bias"].size(0), 3);
|
||||
ASSERT_EQ(parameters["test.l1.weight"].size(0), 3);
|
||||
ASSERT_EQ(parameters["test.l1.weight"].size(1), 10);
|
||||
ASSERT_EQ(parameters["test.l2.bias"].size(0), 5);
|
||||
ASSERT_EQ(parameters["test.l2.weight"].size(0), 5);
|
||||
ASSERT_EQ(parameters["test.l2.weight"].size(1), 3);
|
||||
ASSERT_EQ(parameters["test.l3.bias"].size(0), 100);
|
||||
ASSERT_EQ(parameters["test.l3.weight"].size(0), 100);
|
||||
ASSERT_EQ(parameters["test.l3.weight"].size(1), 5);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, FunctionalCallsSuppliedFunction) {
|
||||
bool was_called = false;
|
||||
auto functional = Functional([&was_called](torch::Tensor input) {
|
||||
was_called = true;
|
||||
return input;
|
||||
});
|
||||
auto output = functional->forward(torch::ones(5, torch::requires_grad()));
|
||||
ASSERT_TRUE(was_called);
|
||||
ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
|
||||
|
||||
was_called = false;
|
||||
// Use the call operator overload here.
|
||||
output = functional(torch::ones(5, torch::requires_grad()));
|
||||
ASSERT_TRUE(was_called);
|
||||
ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, FunctionalWithTorchFunction) {
|
||||
auto functional = Functional(torch::relu);
|
||||
ASSERT_EQ(functional(torch::ones({})).toCFloat(), 1);
|
||||
ASSERT_EQ(functional(torch::ones({})).toCFloat(), 1);
|
||||
ASSERT_EQ(functional(torch::ones({}) * -1).toCFloat(), 0);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, FunctionalArgumentBinding) {
|
||||
auto functional =
|
||||
Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1);
|
||||
ASSERT_EQ(functional(torch::ones({})).toCFloat(), 0);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, BatchNormStateful) {
|
||||
BatchNorm bn(5);
|
||||
|
||||
// Is stateful by default.
|
||||
ASSERT_TRUE(bn->options.stateful());
|
||||
|
||||
ASSERT_TRUE(bn->running_mean.defined());
|
||||
ASSERT_EQ(bn->running_mean.dim(), 1);
|
||||
ASSERT_EQ(bn->running_mean.size(0), 5);
|
||||
|
||||
ASSERT_TRUE(bn->running_variance.defined());
|
||||
ASSERT_EQ(bn->running_variance.dim(), 1);
|
||||
ASSERT_EQ(bn->running_variance.size(0), 5);
|
||||
|
||||
// Is affine by default.
|
||||
ASSERT_TRUE(bn->options.affine());
|
||||
|
||||
ASSERT_TRUE(bn->weight.defined());
|
||||
ASSERT_EQ(bn->weight.dim(), 1);
|
||||
ASSERT_EQ(bn->weight.size(0), 5);
|
||||
|
||||
ASSERT_TRUE(bn->bias.defined());
|
||||
ASSERT_EQ(bn->bias.dim(), 1);
|
||||
ASSERT_EQ(bn->bias.size(0), 5);
|
||||
}
|
||||
TEST_F(ModulesTest, BatchNormStateless) {
|
||||
BatchNorm bn(BatchNormOptions(5).stateful(false).affine(false));
|
||||
|
||||
ASSERT_FALSE(bn->running_mean.defined());
|
||||
ASSERT_FALSE(bn->running_variance.defined());
|
||||
ASSERT_FALSE(bn->weight.defined());
|
||||
ASSERT_FALSE(bn->bias.defined());
|
||||
|
||||
ASSERT_THROWS_WITH(
|
||||
bn->forward(torch::ones({2, 5})),
|
||||
"Calling BatchNorm::forward is only permitted "
|
||||
"when the 'stateful' option is true (was false). "
|
||||
"Use BatchNorm::pure_forward instead.");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, BatchNormPureForward) {
|
||||
BatchNorm bn(BatchNormOptions(5).affine(false));
|
||||
bn->eval();
|
||||
|
||||
// Want to make sure we use the supplied values in `pure_forward` even if
|
||||
// we are stateful.
|
||||
auto input = torch::randn({2, 5});
|
||||
auto mean = torch::randn(5);
|
||||
auto variance = torch::rand(5);
|
||||
auto output = bn->pure_forward(input, mean, variance);
|
||||
auto expected = (input - mean) / torch::sqrt(variance + bn->options.eps());
|
||||
ASSERT_TRUE(output.allclose(expected));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Linear_CUDA) {
|
||||
Linear model(5, 2);
|
||||
model->to(torch::kCUDA);
|
||||
auto x =
|
||||
torch::randn({10, 5}, torch::device(torch::kCUDA).requires_grad(true));
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(y.ndimension(), 2);
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.size(0), 10);
|
||||
ASSERT_EQ(y.size(1), 2);
|
||||
|
||||
ASSERT_EQ(model->parameters()["weight"].grad().numel(), 2 * 5);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Linear2_CUDA) {
|
||||
Linear model(5, 2);
|
||||
model->to(torch::kCUDA);
|
||||
model->to(torch::kCPU);
|
||||
auto x = torch::randn({10, 5}, torch::requires_grad());
|
||||
auto y = model->forward(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(y.ndimension(), 2);
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
ASSERT_EQ(y.size(0), 10);
|
||||
ASSERT_EQ(y.size(1), 2);
|
||||
|
||||
ASSERT_EQ(model->parameters()["weight"].grad().numel(), 2 * 5);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user