Use CATCH prefix to avoid name conflicts with Caffe2.

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11780

Differential Revision: D9889925

Pulled By: gchanan

fbshipit-source-id: 5eca849c36ced00b8ae7482b7945b445a3e1687e
This commit is contained in:
Gregory Chanan
2018-09-18 07:59:41 -07:00
committed by Facebook Github Bot
parent 4ee0a78ee6
commit e00fb69b25
41 changed files with 1689 additions and 1665 deletions

View File

@ -1,4 +1,4 @@
#include <catch.hpp>
#include "catch_utils.hpp"
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/rnn.h>
@ -71,22 +71,22 @@ void check_lstm_sizes(RNNOutput output) {
// Expect the LSTM to have 64 outputs and 3 layers, with an input of batch
// 10 and 16 time steps (10 x 16 x n)
REQUIRE(output.output.ndimension() == 3);
REQUIRE(output.output.size(0) == 10);
REQUIRE(output.output.size(1) == 16);
REQUIRE(output.output.size(2) == 64);
CATCH_REQUIRE(output.output.ndimension() == 3);
CATCH_REQUIRE(output.output.size(0) == 10);
CATCH_REQUIRE(output.output.size(1) == 16);
CATCH_REQUIRE(output.output.size(2) == 64);
REQUIRE(output.state.ndimension() == 4);
REQUIRE(output.state.size(0) == 2); // (hx, cx)
REQUIRE(output.state.size(1) == 3); // layers
REQUIRE(output.state.size(2) == 16); // Batchsize
REQUIRE(output.state.size(3) == 64); // 64 hidden dims
CATCH_REQUIRE(output.state.ndimension() == 4);
CATCH_REQUIRE(output.state.size(0) == 2); // (hx, cx)
CATCH_REQUIRE(output.state.size(1) == 3); // layers
CATCH_REQUIRE(output.state.size(2) == 16); // Batchsize
CATCH_REQUIRE(output.state.size(3) == 64); // 64 hidden dims
// Something is in the hiddens
REQUIRE(output.state.norm().toCFloat() > 0);
CATCH_REQUIRE(output.state.norm().toCFloat() > 0);
}
TEST_CASE("RNN/CheckOutputSizes") {
CATCH_TEST_CASE("RNN/CheckOutputSizes") {
torch::manual_seed(0);
LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
// Input size is: sequence length, batch size, input size
@ -104,10 +104,10 @@ TEST_CASE("RNN/CheckOutputSizes") {
torch::Tensor diff = next.state - output.state;
// Hiddens changed
REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
CATCH_REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
}
TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
CATCH_TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
torch::manual_seed(0);
// Make sure the outputs match pytorch outputs
LSTM model(2, 2);
@ -127,10 +127,10 @@ TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
}
auto out = model->forward(x);
REQUIRE(out.output.ndimension() == 3);
REQUIRE(out.output.size(0) == 3);
REQUIRE(out.output.size(1) == 4);
REQUIRE(out.output.size(2) == 2);
CATCH_REQUIRE(out.output.ndimension() == 3);
CATCH_REQUIRE(out.output.size(0) == 3);
CATCH_REQUIRE(out.output.size(1) == 4);
CATCH_REQUIRE(out.output.size(2) == 2);
auto flat = out.output.view(3 * 4 * 2);
float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
@ -138,14 +138,14 @@ TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
for (size_t i = 0; i < 3 * 4 * 2; i++) {
REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
CATCH_REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
}
REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
REQUIRE(out.state.size(0) == 2);
REQUIRE(out.state.size(1) == 1);
REQUIRE(out.state.size(2) == 4);
REQUIRE(out.state.size(3) == 2);
CATCH_REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
CATCH_REQUIRE(out.state.size(0) == 2);
CATCH_REQUIRE(out.state.size(1) == 1);
CATCH_REQUIRE(out.state.size(2) == 4);
CATCH_REQUIRE(out.state.size(3) == 2);
flat = out.state.view(16);
float h_out[] = {0.7889,
0.9003,
@ -164,33 +164,33 @@ TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
1.0931,
1.4911};
for (size_t i = 0; i < 16; i++) {
REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
CATCH_REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
}
}
TEST_CASE("RNN/integration/LSTM") {
REQUIRE(test_RNN_xor<LSTM>(
CATCH_TEST_CASE("RNN/integration/LSTM") {
CATCH_REQUIRE(test_RNN_xor<LSTM>(
[](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }));
}
TEST_CASE("RNN/integration/GRU") {
REQUIRE(
CATCH_TEST_CASE("RNN/integration/GRU") {
CATCH_REQUIRE(
test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).layers(2)); }));
}
TEST_CASE("RNN/integration/RNN") {
SECTION("relu") {
REQUIRE(test_RNN_xor<RNN>(
CATCH_TEST_CASE("RNN/integration/RNN") {
CATCH_SECTION("relu") {
CATCH_REQUIRE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }));
}
SECTION("tanh") {
REQUIRE(test_RNN_xor<RNN>(
CATCH_SECTION("tanh") {
CATCH_REQUIRE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }));
}
}
TEST_CASE("rnn_cuda", "[cuda]") {
SECTION("sizes") {
CATCH_TEST_CASE("rnn_cuda", "[cuda]") {
CATCH_SECTION("sizes") {
torch::manual_seed(0);
LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
model->to(torch::kCUDA);
@ -209,26 +209,26 @@ TEST_CASE("rnn_cuda", "[cuda]") {
torch::Tensor diff = next.state - output.state;
// Hiddens changed
REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
CATCH_REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
}
SECTION("lstm") {
REQUIRE(test_RNN_xor<LSTM>(
CATCH_SECTION("lstm") {
CATCH_REQUIRE(test_RNN_xor<LSTM>(
[](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }, true));
}
SECTION("gru") {
REQUIRE(test_RNN_xor<GRU>(
CATCH_SECTION("gru") {
CATCH_REQUIRE(test_RNN_xor<GRU>(
[](int s) { return GRU(GRUOptions(s, s).layers(2)); }, true));
}
SECTION("rnn") {
SECTION("relu") {
REQUIRE(test_RNN_xor<RNN>(
CATCH_SECTION("rnn") {
CATCH_SECTION("relu") {
CATCH_REQUIRE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }, true));
}
SECTION("tanh") {
REQUIRE(test_RNN_xor<RNN>(
CATCH_SECTION("tanh") {
CATCH_REQUIRE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }, true));
}
}