mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Rewrite C++ API tests in gtest (#11953)
Summary: This PR is a large codemod to rewrite all C++ API tests with GoogleTest (gtest) instead of Catch. You can largely trust me to have correctly code-modded the tests, so it's not required to review every of the 2000+ changed lines. However, additional things I changed were: 1. Moved the cmake parts for these tests into their own `CMakeLists.txt` under `test/cpp/api` and calling `add_subdirectory` from `torch/CMakeLists.txt` 2. Fixing DataParallel tests which weren't being compiled because `USE_CUDA` wasn't correctly being set at all. 3. Updated README ezyang ebetica Pull Request resolved: https://github.com/pytorch/pytorch/pull/11953 Differential Revision: D9998883 Pulled By: goldsborough fbshipit-source-id: affe3f320b0ca63e7e0019926a59076bb943db80
This commit is contained in:
committed by
Facebook Github Bot
parent
d0db23e95a
commit
825181ea9d
@ -1,4 +1,4 @@
|
||||
#include "catch_utils.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/nn/modules/functional.h>
|
||||
@ -9,7 +9,7 @@
|
||||
#include <torch/utils.h>
|
||||
|
||||
#include <test/cpp/api/optim_baseline.h>
|
||||
#include <test/cpp/api/util.h>
|
||||
#include <test/cpp/api/support.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
@ -118,24 +118,24 @@ void check_exact_values(
|
||||
optimizer.step();
|
||||
|
||||
if (i % kSampleEvery == 0) {
|
||||
CATCH_REQUIRE(
|
||||
ASSERT_TRUE(
|
||||
expected_parameters.at(i / kSampleEvery).size() == parameters.size());
|
||||
for (size_t p = 0; p < parameters.size(); ++p) {
|
||||
CATCH_REQUIRE(parameters.at(p)->defined());
|
||||
ASSERT_TRUE(parameters.at(p)->defined());
|
||||
auto computed = parameters.at(p)->flatten();
|
||||
auto expected = expected_parameters.at(i / kSampleEvery).at(p);
|
||||
if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
|
||||
std::cout << "Iteration " << i << ": " << computed
|
||||
<< " != " << expected << " (parameter " << p << ")"
|
||||
<< std::endl;
|
||||
CATCH_REQUIRE(false);
|
||||
ASSERT_TRUE(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/BasicInterface") {
|
||||
TEST(OptimTest, BasicInterface) {
|
||||
struct MyOptimizer : Optimizer {
|
||||
using Optimizer::Optimizer;
|
||||
void step() override {}
|
||||
@ -144,139 +144,140 @@ CATCH_TEST_CASE("Optim/BasicInterface") {
|
||||
torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
|
||||
{
|
||||
MyOptimizer optimizer(parameters);
|
||||
CATCH_REQUIRE(optimizer.size() == parameters.size());
|
||||
ASSERT_EQ(optimizer.size(), parameters.size());
|
||||
}
|
||||
{
|
||||
MyOptimizer optimizer;
|
||||
CATCH_REQUIRE(optimizer.size() == 0);
|
||||
ASSERT_EQ(optimizer.size(), 0);
|
||||
optimizer.add_parameters(parameters);
|
||||
CATCH_REQUIRE(optimizer.size() == parameters.size());
|
||||
ASSERT_EQ(optimizer.size(), parameters.size());
|
||||
for (size_t p = 0; p < parameters.size(); ++p) {
|
||||
CATCH_REQUIRE(optimizer.parameters()[p].allclose(parameters[p]));
|
||||
ASSERT_TRUE(optimizer.parameters()[p].allclose(parameters[p]));
|
||||
}
|
||||
}
|
||||
{
|
||||
Linear linear(3, 4);
|
||||
MyOptimizer optimizer(linear->parameters());
|
||||
CATCH_REQUIRE(optimizer.size() == linear->parameters().size());
|
||||
ASSERT_EQ(optimizer.size(), linear->parameters().size());
|
||||
}
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/XORConvergence/SGD") {
|
||||
CATCH_REQUIRE(test_optimizer_xor<SGD>(
|
||||
TEST(OptimTest, XORConvergence_SGD) {
|
||||
ASSERT_TRUE(test_optimizer_xor<SGD>(
|
||||
SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/XORConvergence/Adagrad") {
|
||||
CATCH_REQUIRE(test_optimizer_xor<Adagrad>(
|
||||
TEST(OptimTest, XORConvergence_Adagrad) {
|
||||
ASSERT_TRUE(test_optimizer_xor<Adagrad>(
|
||||
AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/XORConvergence/RMSprop") {
|
||||
CATCH_REQUIRE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
|
||||
TEST(OptimTest, XORConvergence_RMSprop) {
|
||||
ASSERT_TRUE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/XORConvergence/RMSpropWithMomentum") {
|
||||
CATCH_REQUIRE(test_optimizer_xor<RMSprop>(
|
||||
TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
|
||||
ASSERT_TRUE(test_optimizer_xor<RMSprop>(
|
||||
RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/XORConvergence/Adam") {
|
||||
CATCH_REQUIRE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
|
||||
TEST(OptimTest, XORConvergence_Adam) {
|
||||
ASSERT_TRUE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/XORConvergence/AdamWithAmsgrad") {
|
||||
CATCH_REQUIRE(test_optimizer_xor<Adam>(
|
||||
TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
|
||||
ASSERT_TRUE(test_optimizer_xor<Adam>(
|
||||
AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/Adam") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_Adam) {
|
||||
check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdamWithWeightDecay") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
|
||||
check_exact_values<Adam>(
|
||||
AdamOptions(1.0).weight_decay(1e-2),
|
||||
expected_parameters::Adam_with_weight_decay);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdamWithWeightDecayAndAMSGrad") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
|
||||
check_exact_values<Adam>(
|
||||
AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
|
||||
expected_parameters::Adam_with_weight_decay_and_amsgrad);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/Adagrad") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
|
||||
check_exact_values<Adagrad>(
|
||||
AdagradOptions(1.0), expected_parameters::Adagrad);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdagradWithWeightDecay") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
|
||||
check_exact_values<Adagrad>(
|
||||
AdagradOptions(1.0).weight_decay(1e-2),
|
||||
expected_parameters::Adagrad_with_weight_decay);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdagradWithWeightDecayAndLRDecay") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
|
||||
check_exact_values<Adagrad>(
|
||||
AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
|
||||
expected_parameters::Adagrad_with_weight_decay_and_lr_decay);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/RMSprop") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
|
||||
check_exact_values<RMSprop>(
|
||||
RMSpropOptions(0.1), expected_parameters::RMSprop);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/RMSpropWithWeightDecay") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
|
||||
check_exact_values<RMSprop>(
|
||||
RMSpropOptions(0.1).weight_decay(1e-2),
|
||||
expected_parameters::RMSprop_with_weight_decay);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/RMSpropWithWeightDecayAndCentered") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
|
||||
check_exact_values<RMSprop>(
|
||||
RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
|
||||
expected_parameters::RMSprop_with_weight_decay_and_centered);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE(
|
||||
"Optim/ProducesPyTorchValues/RMSpropWithWeightDecayAndCenteredAndMomentum") {
|
||||
TEST(
|
||||
OptimTest,
|
||||
ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
|
||||
check_exact_values<RMSprop>(
|
||||
RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
|
||||
expected_parameters::RMSprop_with_weight_decay_and_centered_and_momentum);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGD") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_SGD) {
|
||||
check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecay") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
|
||||
check_exact_values<SGD>(
|
||||
SGDOptions(0.1).weight_decay(1e-2),
|
||||
expected_parameters::SGD_with_weight_decay);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecayAndMomentum") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
|
||||
check_exact_values<SGD>(
|
||||
SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
|
||||
expected_parameters::SGD_with_weight_decay_and_momentum);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecayAndNesterovMomentum") {
|
||||
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
|
||||
check_exact_values<SGD>(
|
||||
SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
|
||||
expected_parameters::SGD_with_weight_decay_and_nesterov_momentum);
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ZeroGrad") {
|
||||
TEST(OptimTest, ZeroGrad) {
|
||||
torch::manual_seed(0);
|
||||
|
||||
Linear model(2, 8);
|
||||
SGD optimizer(model->parameters(), 0.1);
|
||||
|
||||
for (const auto& parameter : model->parameters()) {
|
||||
CATCH_REQUIRE(!parameter->grad().defined());
|
||||
ASSERT_FALSE(parameter->grad().defined());
|
||||
}
|
||||
|
||||
auto output = model->forward(torch::ones({5, 2}));
|
||||
@ -284,19 +285,19 @@ CATCH_TEST_CASE("Optim/ZeroGrad") {
|
||||
loss.backward();
|
||||
|
||||
for (const auto& parameter : model->parameters()) {
|
||||
CATCH_REQUIRE(parameter->grad().defined());
|
||||
CATCH_REQUIRE(parameter->grad().sum().toCFloat() > 0);
|
||||
ASSERT_TRUE(parameter->grad().defined());
|
||||
ASSERT_GT(parameter->grad().sum().toCFloat(), 0);
|
||||
}
|
||||
|
||||
optimizer.zero_grad();
|
||||
|
||||
for (const auto& parameter : model->parameters()) {
|
||||
CATCH_REQUIRE(parameter->grad().defined());
|
||||
CATCH_REQUIRE(parameter->grad().sum().toCFloat() == 0);
|
||||
ASSERT_TRUE(parameter->grad().defined());
|
||||
ASSERT_EQ(parameter->grad().sum().toCFloat(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/ExternalVectorOfParameters") {
|
||||
TEST(OptimTest, ExternalVectorOfParameters) {
|
||||
torch::manual_seed(0);
|
||||
|
||||
std::vector<torch::Tensor> parameters = {
|
||||
@ -313,12 +314,12 @@ CATCH_TEST_CASE("Optim/ExternalVectorOfParameters") {
|
||||
|
||||
optimizer.step();
|
||||
|
||||
CATCH_REQUIRE(parameters[0].allclose(original_parameters[0] - 1.0));
|
||||
CATCH_REQUIRE(parameters[1].allclose(original_parameters[1] - 1.0));
|
||||
CATCH_REQUIRE(parameters[2].allclose(original_parameters[2] - 1.0));
|
||||
ASSERT_TRUE(parameters[0].allclose(original_parameters[0] - 1.0));
|
||||
ASSERT_TRUE(parameters[1].allclose(original_parameters[1] - 1.0));
|
||||
ASSERT_TRUE(parameters[2].allclose(original_parameters[2] - 1.0));
|
||||
}
|
||||
|
||||
CATCH_TEST_CASE("Optim/AddParameter/LBFGS") {
|
||||
TEST(OptimTest, AddParameter_LBFGS) {
|
||||
torch::manual_seed(0);
|
||||
|
||||
std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
|
||||
|
Reference in New Issue
Block a user