Use CATCH prefix to avoid name conflicts with Caffe2.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11780
Differential Revision: D9889925
Pulled By: gchanan
fbshipit-source-id: 5eca849c36ced00b8ae7482b7945b445a3e1687e
Committed by: Facebook Github Bot
Parent: 4ee0a78ee6
Commit: e00fb69b25
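The change swaps the test file over from including Catch directly to including a wrapper header, catch_utils.hpp, and renames every Catch macro to its CATCH_-prefixed form (TEST_CASE -> CATCH_TEST_CASE, REQUIRE -> CATCH_REQUIRE, SECTION -> CATCH_SECTION). Catch supports this natively: defining CATCH_CONFIG_PREFIX_ALL before including catch.hpp makes the framework emit only the prefixed macro names, so short names like REQUIRE stay free for Caffe2's own macros. A minimal sketch of what the wrapper header could look like (the actual contents of catch_utils.hpp are not shown in this diff, so treat this as an assumption):

// catch_utils.hpp -- hypothetical sketch, not the verbatim header from this PR.
#pragma once

// Ask Catch to prefix all of its macros with CATCH_ (CATCH_TEST_CASE,
// CATCH_REQUIRE, CATCH_SECTION, ...) so the unprefixed names cannot
// collide with macros defined by Caffe2.
#define CATCH_CONFIG_PREFIX_ALL
#include <catch.hpp>

With that in place, the mechanical rename below is the entire change: every test keeps the same body and assertions, and only the macro names gain the CATCH_ prefix.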
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/nn/modules/linear.h>
 #include <torch/nn/modules/rnn.h>
@@ -71,22 +71,22 @@ void check_lstm_sizes(RNNOutput output) {
   // Expect the LSTM to have 64 outputs and 3 layers, with an input of batch
   // 10 and 16 time steps (10 x 16 x n)
 
-  REQUIRE(output.output.ndimension() == 3);
-  REQUIRE(output.output.size(0) == 10);
-  REQUIRE(output.output.size(1) == 16);
-  REQUIRE(output.output.size(2) == 64);
+  CATCH_REQUIRE(output.output.ndimension() == 3);
+  CATCH_REQUIRE(output.output.size(0) == 10);
+  CATCH_REQUIRE(output.output.size(1) == 16);
+  CATCH_REQUIRE(output.output.size(2) == 64);
 
-  REQUIRE(output.state.ndimension() == 4);
-  REQUIRE(output.state.size(0) == 2); // (hx, cx)
-  REQUIRE(output.state.size(1) == 3); // layers
-  REQUIRE(output.state.size(2) == 16); // Batchsize
-  REQUIRE(output.state.size(3) == 64); // 64 hidden dims
+  CATCH_REQUIRE(output.state.ndimension() == 4);
+  CATCH_REQUIRE(output.state.size(0) == 2); // (hx, cx)
+  CATCH_REQUIRE(output.state.size(1) == 3); // layers
+  CATCH_REQUIRE(output.state.size(2) == 16); // Batchsize
+  CATCH_REQUIRE(output.state.size(3) == 64); // 64 hidden dims
 
   // Something is in the hiddens
-  REQUIRE(output.state.norm().toCFloat() > 0);
+  CATCH_REQUIRE(output.state.norm().toCFloat() > 0);
 }
 
-TEST_CASE("RNN/CheckOutputSizes") {
+CATCH_TEST_CASE("RNN/CheckOutputSizes") {
   torch::manual_seed(0);
   LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
   // Input size is: sequence length, batch size, input size
@@ -104,10 +104,10 @@ TEST_CASE("RNN/CheckOutputSizes") {
   torch::Tensor diff = next.state - output.state;
 
   // Hiddens changed
-  REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
+  CATCH_REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
 }
 
-TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
+CATCH_TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
   torch::manual_seed(0);
   // Make sure the outputs match pytorch outputs
   LSTM model(2, 2);
@@ -127,10 +127,10 @@ TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
   }
 
   auto out = model->forward(x);
-  REQUIRE(out.output.ndimension() == 3);
-  REQUIRE(out.output.size(0) == 3);
-  REQUIRE(out.output.size(1) == 4);
-  REQUIRE(out.output.size(2) == 2);
+  CATCH_REQUIRE(out.output.ndimension() == 3);
+  CATCH_REQUIRE(out.output.size(0) == 3);
+  CATCH_REQUIRE(out.output.size(1) == 4);
+  CATCH_REQUIRE(out.output.size(2) == 2);
 
   auto flat = out.output.view(3 * 4 * 2);
   float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
@@ -138,14 +138,14 @@ TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
                    0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
                    0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
   for (size_t i = 0; i < 3 * 4 * 2; i++) {
-    REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
+    CATCH_REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
   }
 
-  REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
-  REQUIRE(out.state.size(0) == 2);
-  REQUIRE(out.state.size(1) == 1);
-  REQUIRE(out.state.size(2) == 4);
-  REQUIRE(out.state.size(3) == 2);
+  CATCH_REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
+  CATCH_REQUIRE(out.state.size(0) == 2);
+  CATCH_REQUIRE(out.state.size(1) == 1);
+  CATCH_REQUIRE(out.state.size(2) == 4);
+  CATCH_REQUIRE(out.state.size(3) == 2);
   flat = out.state.view(16);
   float h_out[] = {0.7889,
                    0.9003,
@@ -164,33 +164,33 @@ TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
                    1.0931,
                    1.4911};
   for (size_t i = 0; i < 16; i++) {
-    REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
+    CATCH_REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
   }
 }
 
-TEST_CASE("RNN/integration/LSTM") {
-  REQUIRE(test_RNN_xor<LSTM>(
+CATCH_TEST_CASE("RNN/integration/LSTM") {
+  CATCH_REQUIRE(test_RNN_xor<LSTM>(
       [](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }));
 }
 
-TEST_CASE("RNN/integration/GRU") {
-  REQUIRE(
+CATCH_TEST_CASE("RNN/integration/GRU") {
+  CATCH_REQUIRE(
       test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).layers(2)); }));
 }
 
-TEST_CASE("RNN/integration/RNN") {
-  SECTION("relu") {
-    REQUIRE(test_RNN_xor<RNN>(
+CATCH_TEST_CASE("RNN/integration/RNN") {
+  CATCH_SECTION("relu") {
+    CATCH_REQUIRE(test_RNN_xor<RNN>(
         [](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }));
   }
-  SECTION("tanh") {
-    REQUIRE(test_RNN_xor<RNN>(
+  CATCH_SECTION("tanh") {
+    CATCH_REQUIRE(test_RNN_xor<RNN>(
         [](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }));
   }
 }
 
-TEST_CASE("rnn_cuda", "[cuda]") {
-  SECTION("sizes") {
+CATCH_TEST_CASE("rnn_cuda", "[cuda]") {
+  CATCH_SECTION("sizes") {
     torch::manual_seed(0);
     LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
     model->to(torch::kCUDA);
@@ -209,26 +209,26 @@ TEST_CASE("rnn_cuda", "[cuda]") {
     torch::Tensor diff = next.state - output.state;
 
     // Hiddens changed
-    REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
+    CATCH_REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
   }
 
-  SECTION("lstm") {
-    REQUIRE(test_RNN_xor<LSTM>(
+  CATCH_SECTION("lstm") {
+    CATCH_REQUIRE(test_RNN_xor<LSTM>(
         [](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }, true));
   }
 
-  SECTION("gru") {
-    REQUIRE(test_RNN_xor<GRU>(
+  CATCH_SECTION("gru") {
+    CATCH_REQUIRE(test_RNN_xor<GRU>(
        [](int s) { return GRU(GRUOptions(s, s).layers(2)); }, true));
   }
 
-  SECTION("rnn") {
-    SECTION("relu") {
-      REQUIRE(test_RNN_xor<RNN>(
+  CATCH_SECTION("rnn") {
+    CATCH_SECTION("relu") {
+      CATCH_REQUIRE(test_RNN_xor<RNN>(
          [](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }, true));
    }
-    SECTION("tanh") {
-      REQUIRE(test_RNN_xor<RNN>(
+    CATCH_SECTION("tanh") {
+      CATCH_REQUIRE(test_RNN_xor<RNN>(
          [](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }, true));
    }
  }