Revert D20311699: [pytorch][PR] [C++ API] RNN / GRU / LSTM layer refactoring

Test Plan: revert-hammer

Differential Revision: D20311699

Original commit changeset: e2b60fc7bac6

fbshipit-source-id: 72f4a762189490998d6b716857eeac053a11742d
Will Feng
2020-03-14 16:15:47 -07:00
committed by Facebook GitHub Bot
parent 84bd71dbd4
commit 6c555e1508
11 changed files with 455 additions and 947 deletions
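For orientation: this revert replaces the refactored tuple-returning `forward()` API with the earlier `RNNOutput`-based one, as the diffs below show. A minimal sketch of the two call styles (illustrative only, using names taken from this diff):

// After this revert: forward() returns an RNNOutput struct.
LSTM lstm(LSTMOptions(128, 64).layers(3).dropout(0.2));
auto x = torch::randn({10, 16, 128});
auto out = lstm->forward(x);   // out.output and out.state are Tensors
// Before the revert, the refactored API returned
// std::tuple<Tensor, std::tuple<Tensor, Tensor>> instead:
//   auto result = lstm->forward(x);
//   auto output = std::get<0>(result);
//   auto hx = std::get<0>(std::get<1>(result));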


@@ -43,11 +43,7 @@ TEST(EnumTest, AllEnums) {
torch::enumtype::kBatchMean,
torch::enumtype::kZeros,
torch::enumtype::kBorder,
torch::enumtype::kReflection,
torch::enumtype::kRNN_TANH,
torch::enumtype::kRNN_RELU,
torch::enumtype::kLSTM,
torch::enumtype::kGRU
torch::enumtype::kReflection
> v;
TORCH_ENUM_PRETTY_PRINT_TEST(Linear)
@@ -80,8 +76,4 @@ TEST(EnumTest, AllEnums) {
TORCH_ENUM_PRETTY_PRINT_TEST(Zeros)
TORCH_ENUM_PRETTY_PRINT_TEST(Border)
TORCH_ENUM_PRETTY_PRINT_TEST(Reflection)
TORCH_ENUM_PRETTY_PRINT_TEST(RNN_TANH)
TORCH_ENUM_PRETTY_PRINT_TEST(RNN_RELU)
TORCH_ENUM_PRETTY_PRINT_TEST(LSTM)
TORCH_ENUM_PRETTY_PRINT_TEST(GRU)
}


@@ -283,7 +283,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
" (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
" (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
" (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
" (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");
}


@@ -23,7 +23,7 @@ bool test_RNN_xor(Func&& model_maker, bool cuda = false) {
auto B = x.size(1);
x = x.view({T * B, 1});
x = l1->forward(x).view({T, B, nhid}).tanh_();
x = std::get<0>(rnn->forward(x))[T - 1];
x = rnn->forward(x).output[T - 1];
x = lo->forward(x);
return x;
};
@@ -61,39 +61,29 @@ bool test_RNN_xor(Func&& model_maker, bool cuda = false) {
return true;
};
void check_lstm_sizes(std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output) {
void check_lstm_sizes(RNNOutput output) {
// Expect the LSTM to have 64 outputs and 3 layers, with an input of
// 10 time steps and batch size 16 (10 x 16 x n)
torch::Tensor output = std::get<0>(lstm_output);
std::tuple<torch::Tensor, torch::Tensor> state = std::get<1>(lstm_output);
torch::Tensor hx = std::get<0>(state);
torch::Tensor cx = std::get<1>(state);
ASSERT_EQ(output.output.ndimension(), 3);
ASSERT_EQ(output.output.size(0), 10);
ASSERT_EQ(output.output.size(1), 16);
ASSERT_EQ(output.output.size(2), 64);
ASSERT_EQ(output.ndimension(), 3);
ASSERT_EQ(output.size(0), 10);
ASSERT_EQ(output.size(1), 16);
ASSERT_EQ(output.size(2), 64);
ASSERT_EQ(hx.ndimension(), 3);
ASSERT_EQ(hx.size(0), 3); // layers
ASSERT_EQ(hx.size(1), 16); // Batchsize
ASSERT_EQ(hx.size(2), 64); // 64 hidden dims
ASSERT_EQ(cx.ndimension(), 3);
ASSERT_EQ(cx.size(0), 3); // layers
ASSERT_EQ(cx.size(1), 16); // Batchsize
ASSERT_EQ(cx.size(2), 64); // 64 hidden dims
ASSERT_EQ(output.state.ndimension(), 4);
ASSERT_EQ(output.state.size(0), 2); // (hx, cx)
ASSERT_EQ(output.state.size(1), 3); // layers
ASSERT_EQ(output.state.size(2), 16); // Batchsize
ASSERT_EQ(output.state.size(3), 64); // 64 hidden dims
// Something is in the hiddens
ASSERT_GT(hx.norm().item<float>(), 0);
ASSERT_GT(cx.norm().item<float>(), 0);
ASSERT_GT(output.state.norm().item<float>(), 0);
}
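The assertions above reflect how the restored API packs the LSTM state: instead of an `(hx, cx)` tuple, `state` is a single stacked tensor whose first dimension indexes `hx` and `cx`. A hedged sketch of unpacking it, assuming the layout asserted above:

auto out = model->forward(x);
auto hx = out.state[0];  // shape: (layers * num_directions, batch, hidden)
auto cx = out.state[1];  // same shape; only meaningful for LSTM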
struct RNNTest : torch::test::SeedingFixture {};
TEST_F(RNNTest, CheckOutputSizes) {
LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
// Input size is: sequence length, batch size, input size
auto x = torch::randn({10, 16, 128}, torch::requires_grad());
auto output = model->forward(x);
@@ -102,17 +92,11 @@ TEST_F(RNNTest, CheckOutputSizes) {
y.backward();
check_lstm_sizes(output);
auto next = model->forward(x, std::get<1>(output));
auto next = model->forward(x, output.state);
check_lstm_sizes(next);
auto output_hx = std::get<0>(std::get<1>(output));
auto output_cx = std::get<1>(std::get<1>(output));
auto next_hx = std::get<0>(std::get<1>(next));
auto next_cx = std::get<1>(std::get<1>(next));
torch::Tensor diff = torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0);
torch::Tensor diff = next.state - output.state;
// Hiddens changed
ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
@@ -138,12 +122,12 @@ TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) {
}
auto out = model->forward(x);
ASSERT_EQ(std::get<0>(out).ndimension(), 3);
ASSERT_EQ(std::get<0>(out).size(0), 3);
ASSERT_EQ(std::get<0>(out).size(1), 4);
ASSERT_EQ(std::get<0>(out).size(2), 2);
ASSERT_EQ(out.output.ndimension(), 3);
ASSERT_EQ(out.output.size(0), 3);
ASSERT_EQ(out.output.size(1), 4);
ASSERT_EQ(out.output.size(2), 2);
auto flat = std::get<0>(out).view(3 * 4 * 2);
auto flat = out.output.view(3 * 4 * 2);
float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
@@ -152,20 +136,12 @@ TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) {
ASSERT_LT(std::abs(flat[i].item<float>() - c_out[i]), 1e-3);
}
auto hx = std::get<0>(std::get<1>(out));
auto cx = std::get<1>(std::get<1>(out));
ASSERT_EQ(hx.ndimension(), 3); // layers x B x 2
ASSERT_EQ(hx.size(0), 1);
ASSERT_EQ(hx.size(1), 4);
ASSERT_EQ(hx.size(2), 2);
ASSERT_EQ(cx.ndimension(), 3); // layers x B x 2
ASSERT_EQ(cx.size(0), 1);
ASSERT_EQ(cx.size(1), 4);
ASSERT_EQ(cx.size(2), 2);
flat = torch::cat({hx, cx}, 0).view(16);
ASSERT_EQ(out.state.ndimension(), 4); // (hx, cx) x layers x B x 2
ASSERT_EQ(out.state.size(0), 2);
ASSERT_EQ(out.state.size(1), 1);
ASSERT_EQ(out.state.size(2), 4);
ASSERT_EQ(out.state.size(3), 2);
flat = out.state.view(16);
float h_out[] = {0.7889,
0.9003,
0.7769,
@@ -189,27 +165,27 @@ TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) {
TEST_F(RNNTest, EndToEndLSTM) {
ASSERT_TRUE(test_RNN_xor<LSTM>(
[](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }));
[](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }));
}
TEST_F(RNNTest, EndToEndGRU) {
ASSERT_TRUE(
test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }));
test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).layers(2)); }));
}
TEST_F(RNNTest, EndToEndRNNRelu) {
ASSERT_TRUE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); }));
[](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }));
}
TEST_F(RNNTest, EndToEndRNNTanh) {
ASSERT_TRUE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); }));
[](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }));
}
TEST_F(RNNTest, Sizes_CUDA) {
torch::manual_seed(0);
LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
model->to(torch::kCUDA);
auto x =
torch::randn({10, 16, 128}, torch::requires_grad().device(torch::kCUDA));
@@ -219,17 +195,11 @@ TEST_F(RNNTest, Sizes_CUDA) {
y.backward();
check_lstm_sizes(output);
auto next = model->forward(x, std::get<1>(output));
auto next = model->forward(x, output.state);
check_lstm_sizes(next);
auto output_hx = std::get<0>(std::get<1>(output));
auto output_cx = std::get<1>(std::get<1>(output));
auto next_hx = std::get<0>(std::get<1>(next));
auto next_cx = std::get<1>(std::get<1>(next));
torch::Tensor diff = torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0);
torch::Tensor diff = next.state - output.state;
// Hiddens changed
ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
@@ -237,68 +207,51 @@ TEST_F(RNNTest, Sizes_CUDA) {
TEST_F(RNNTest, EndToEndLSTM_CUDA) {
ASSERT_TRUE(test_RNN_xor<LSTM>(
[](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }, true));
[](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }, true));
}
TEST_F(RNNTest, EndToEndGRU_CUDA) {
ASSERT_TRUE(test_RNN_xor<GRU>(
[](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }, true));
[](int s) { return GRU(GRUOptions(s, s).layers(2)); }, true));
}
TEST_F(RNNTest, EndToEndRNNRelu_CUDA) {
ASSERT_TRUE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); }, true));
[](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }, true));
}
TEST_F(RNNTest, EndToEndRNNTanh_CUDA) {
ASSERT_TRUE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); }, true));
[](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }, true));
}
TEST_F(RNNTest, PrettyPrintRNNs) {
ASSERT_EQ(
c10::str(LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2))),
"torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");
c10::str(LSTM(LSTMOptions(128, 64).layers(3).dropout(0.2))),
"torch::nn::LSTM(input_size=128, hidden_size=64, layers=3, dropout=0.2)");
ASSERT_EQ(
c10::str(GRU(GRUOptions(128, 64).num_layers(3).dropout(0.5))),
"torch::nn::GRU(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.5, bidirectional=false)");
c10::str(GRU(GRUOptions(128, 64).layers(3).dropout(0.5))),
"torch::nn::GRU(input_size=128, hidden_size=64, layers=3, dropout=0.5)");
ASSERT_EQ(
c10::str(RNN(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh))),
"torch::nn::RNN(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");
c10::str(RNN(RNNOptions(128, 64).layers(3).dropout(0.2).tanh())),
"torch::nn::RNN(input_size=128, hidden_size=64, layers=3, dropout=0.2, activation=tanh)");
}
// This test ensures that flatten_parameters does not crash
// when bidirectional is set to true
// https://github.com/pytorch/pytorch/issues/19545
TEST_F(RNNTest, BidirectionalFlattenParameters) {
GRU gru(GRUOptions(100, 256).num_layers(2).bidirectional(true));
GRU gru(GRUOptions(100, 256).layers(2).bidirectional(true));
gru->flatten_parameters();
}
template <typename Impl>
void copyParameters(torch::nn::ModuleHolder<Impl>& target, std::string t_suffix,
const torch::nn::ModuleHolder<Impl>& source, std::string s_suffix) {
void copyParameters(torch::nn::ModuleHolder<Impl>& target, size_t t_i,
const torch::nn::ModuleHolder<Impl>& source, size_t s_i) {
at::NoGradGuard guard;
target->named_parameters()["weight_ih_l" + t_suffix].copy_(source->named_parameters()["weight_ih_l" + s_suffix]);
target->named_parameters()["weight_hh_l" + t_suffix].copy_(source->named_parameters()["weight_hh_l" + s_suffix]);
target->named_parameters()["bias_ih_l" + t_suffix].copy_(source->named_parameters()["bias_ih_l" + s_suffix]);
target->named_parameters()["bias_hh_l" + t_suffix].copy_(source->named_parameters()["bias_hh_l" + s_suffix]);
}
std::tuple<torch::Tensor, torch::Tensor> gru_output_to_device(
std::tuple<torch::Tensor, torch::Tensor> gru_output, torch::Device device) {
return std::make_tuple(
std::get<0>(gru_output).to(device),
std::get<1>(gru_output).to(device));
}
std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output_to_device(
std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output, torch::Device device) {
auto hidden_states = std::get<1>(lstm_output);
return std::make_tuple(
std::get<0>(lstm_output).to(device),
std::make_tuple(
std::get<0>(hidden_states).to(device),
std::get<1>(hidden_states).to(device)));
target->w_ih[t_i].copy_(source->w_ih[s_i]);
target->w_hh[t_i].copy_(source->w_hh[s_i]);
target->b_ih[t_i].copy_(source->b_ih[s_i]);
target->b_hh[t_i].copy_(source->b_hh[s_i]);
}
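The index-based `copyParameters` above addresses weights through the flat per-layer vectors (`w_ih`, `w_hh`, `b_ih`, `b_hh`), where layer and direction fold into a single index. A small illustration of that mapping (mirroring the `layer_idx` arithmetic used elsewhere in this diff; the module names come from the test below):

// For a bidirectional module, forward weights sit at even indices
// and reverse weights at odd indices:
const auto num_directions = 2;
const auto layer_idx = (/*layer=*/0 * num_directions) + /*direction=*/1;  // == 1
copyParameters(reverse_gru, /*t_i=*/0, bi_grus, /*s_i=*/layer_idx);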
// This test is a port of python code introduced here:
@@ -311,7 +264,7 @@ void BidirectionalGRUReverseForward(bool cuda) {
auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});
auto gru_options = GRUOptions(1, 1).num_layers(1).batch_first(false);
auto gru_options = GRUOptions(1, 1).layers(1).batch_first(false);
GRU bi_grus {gru_options.bidirectional(true)};
GRU reverse_gru {gru_options.bidirectional(false)};
@@ -322,26 +275,28 @@ void BidirectionalGRUReverseForward(bool cuda) {
// Now make sure the weights of the reverse gru layer match
// ones of the (reversed) bidirectional's:
copyParameters(reverse_gru, "0", bi_grus, "0_reverse");
copyParameters(reverse_gru, 0, bi_grus, 1);
auto bi_output = bi_grus->forward(input);
auto reverse_output = reverse_gru->forward(input_reversed);
if (cuda) {
bi_output = gru_output_to_device(bi_output, torch::kCPU);
reverse_output = gru_output_to_device(reverse_output, torch::kCPU);
bi_output.output = bi_output.output.to(torch::kCPU);
bi_output.state = bi_output.state.to(torch::kCPU);
reverse_output.output = reverse_output.output.to(torch::kCPU);
reverse_output.state = reverse_output.state.to(torch::kCPU);
}
ASSERT_EQ(std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
auto size = std::get<0>(bi_output).size(0);
ASSERT_EQ(bi_output.output.size(0), reverse_output.output.size(0));
auto size = bi_output.output.size(0);
for (int i = 0; i < size; i++) {
ASSERT_EQ(std::get<0>(bi_output)[i][0][1].item<float>(),
std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
ASSERT_EQ(bi_output.output[i][0][1].item<float>(),
reverse_output.output[size - 1 - i][0][0].item<float>());
}
// The hidden states of the reversed GRUs sit
// in the odd indices in the first dimension.
ASSERT_EQ(std::get<1>(bi_output)[1][0][0].item<float>(),
std::get<1>(reverse_output)[0][0][0].item<float>());
ASSERT_EQ(bi_output.state[1][0][0].item<float>(),
reverse_output.state[0][0][0].item<float>());
}
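The indexing in this test follows the bidirectional layout of the restored API: per time step the last output dimension holds the forward features followed by the reverse features, and the reverse-direction hidden states occupy the odd rows of `state`'s first dimension. A hedged sketch of extracting only the reverse direction (hidden_size == 1, as in this test):

auto bi_out = bi_grus->forward(input);
auto reverse_features = bi_out.output.select(/*dim=*/2, /*index=*/1);
auto reverse_hidden = bi_out.state[1];  // odd index = reverse direction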
TEST_F(RNNTest, BidirectionalGRUReverseForward) {
@@ -360,7 +315,7 @@ void BidirectionalLSTMReverseForwardTest(bool cuda) {
auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});
auto lstm_opt = LSTMOptions(1, 1).num_layers(1).batch_first(false);
auto lstm_opt = GRUOptions(1, 1).layers(1).batch_first(false);
LSTM bi_lstm {lstm_opt.bidirectional(true)};
LSTM reverse_lstm {lstm_opt.bidirectional(false)};
@@ -372,28 +327,30 @@ void BidirectionalLSTMReverseForwardTest(bool cuda) {
// Now make sure the weights of the reverse lstm layer match
// ones of the (reversed) bidirectional's:
copyParameters(reverse_lstm, "0", bi_lstm, "0_reverse");
copyParameters(reverse_lstm, 0, bi_lstm, 1);
auto bi_output = bi_lstm->forward(input);
auto reverse_output = reverse_lstm->forward(input_reversed);
if (cuda) {
bi_output = lstm_output_to_device(bi_output, torch::kCPU);
reverse_output = lstm_output_to_device(reverse_output, torch::kCPU);
bi_output.output = bi_output.output.to(torch::kCPU);
bi_output.state = bi_output.state.to(torch::kCPU);
reverse_output.output = reverse_output.output.to(torch::kCPU);
reverse_output.state = reverse_output.state.to(torch::kCPU);
}
ASSERT_EQ(std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
auto size = std::get<0>(bi_output).size(0);
ASSERT_EQ(bi_output.output.size(0), reverse_output.output.size(0));
auto size = bi_output.output.size(0);
for (int i = 0; i < size; i++) {
ASSERT_EQ(std::get<0>(bi_output)[i][0][1].item<float>(),
std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
ASSERT_EQ(bi_output.output[i][0][1].item<float>(),
reverse_output.output[size - 1 - i][0][0].item<float>());
}
// The hidden states of the reversed LSTM sit
// in the odd indices in the first dimension.
ASSERT_EQ(std::get<0>(std::get<1>(bi_output))[1][0][0].item<float>(),
std::get<0>(std::get<1>(reverse_output))[0][0][0].item<float>());
ASSERT_EQ(std::get<1>(std::get<1>(bi_output))[1][0][0].item<float>(),
std::get<1>(std::get<1>(reverse_output))[0][0][0].item<float>());
ASSERT_EQ(bi_output.state[0][1][0][0].item<float>(),
reverse_output.state[0][0][0][0].item<float>());
ASSERT_EQ(bi_output.state[1][1][0][0].item<float>(),
reverse_output.state[1][0][0][0].item<float>());
}
TEST_F(RNNTest, BidirectionalLSTMReverseForward) {
@@ -406,15 +363,19 @@ TEST_F(RNNTest, BidirectionalLSTMReverseForward_CUDA) {
TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
// Create two GRUs with the same options
auto opt = GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
auto opt = GRUOptions(2, 4).layers(3).batch_first(false).bidirectional(true);
GRU gru_cpu {opt};
GRU gru_cuda {opt};
// Copy weights and biases from CPU GRU to CUDA GRU
{
at::NoGradGuard guard;
for (const auto& param : gru_cpu->named_parameters(/*recurse=*/false)) {
gru_cuda->named_parameters()[param.key()].copy_(gru_cpu->named_parameters()[param.key()]);
const auto num_directions = gru_cpu->options.bidirectional() ? 2 : 1;
for (int64_t layer = 0; layer < gru_cpu->options.layers(); layer++) {
for (auto direction = 0; direction < num_directions; direction++) {
const auto layer_idx = (layer * num_directions) + direction;
copyParameters(gru_cuda, layer_idx, gru_cpu, layer_idx);
}
}
}
@@ -436,19 +397,20 @@ TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
auto output_cpu = gru_cpu->forward(input_cpu);
auto output_cuda = gru_cuda->forward(input_cuda);
output_cpu = gru_output_to_device(output_cpu, torch::kCPU);
output_cpu.output = output_cpu.output.to(torch::kCPU);
output_cpu.state = output_cpu.state.to(torch::kCPU);
// Assert that the output and state are equal on CPU and CUDA
ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
ASSERT_EQ(output_cpu.output.dim(), output_cuda.output.dim());
for (int i = 0; i < output_cpu.output.dim(); i++) {
ASSERT_EQ(output_cpu.output.size(i), output_cuda.output.size(i));
}
for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
for (int i = 0; i < output_cpu.output.size(0); i++) {
for (int j = 0; j < output_cpu.output.size(1); j++) {
for (int k = 0; k < output_cpu.output.size(2); k++) {
ASSERT_NEAR(
std::get<0>(output_cpu)[i][j][k].item<float>(),
std::get<0>(output_cuda)[i][j][k].item<float>(), 1e-5);
output_cpu.output[i][j][k].item<float>(),
output_cuda.output[i][j][k].item<float>(), 1e-5);
}
}
}
@@ -456,15 +418,19 @@ TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
// Create two LSTMs with the same options
auto opt = LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
auto opt = LSTMOptions(2, 4).layers(3).batch_first(false).bidirectional(true);
LSTM lstm_cpu {opt};
LSTM lstm_cuda {opt};
// Copy weights and biases from CPU LSTM to CUDA LSTM
{
at::NoGradGuard guard;
for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) {
lstm_cuda->named_parameters()[param.key()].copy_(lstm_cpu->named_parameters()[param.key()]);
const auto num_directions = lstm_cpu->options.bidirectional() ? 2 : 1;
for (int64_t layer = 0; layer < lstm_cpu->options.layers(); layer++) {
for (auto direction = 0; direction < num_directions; direction++) {
const auto layer_idx = (layer * num_directions) + direction;
copyParameters(lstm_cuda, layer_idx, lstm_cpu, layer_idx);
}
}
}
@@ -485,68 +451,21 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
auto output_cpu = lstm_cpu->forward(input_cpu);
auto output_cuda = lstm_cuda->forward(input_cuda);
output_cpu = lstm_output_to_device(output_cpu, torch::kCPU);
output_cpu.output = output_cpu.output.to(torch::kCPU);
output_cpu.state = output_cpu.state.to(torch::kCPU);
// Assert that the output and state are equal on CPU and CUDA
ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
ASSERT_EQ(output_cpu.output.dim(), output_cuda.output.dim());
for (int i = 0; i < output_cpu.output.dim(); i++) {
ASSERT_EQ(output_cpu.output.size(i), output_cuda.output.size(i));
}
for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
for (int i = 0; i < output_cpu.output.size(0); i++) {
for (int j = 0; j < output_cpu.output.size(1); j++) {
for (int k = 0; k < output_cpu.output.size(2); k++) {
ASSERT_NEAR(
std::get<0>(output_cpu)[i][j][k].item<float>(),
std::get<0>(output_cuda)[i][j][k].item<float>(), 1e-5);
output_cpu.output[i][j][k].item<float>(),
output_cuda.output[i][j][k].item<float>(), 1e-5);
}
}
}
}
TEST_F(RNNTest, UsePackedSequenceAsInput) {
{
torch::manual_seed(0);
auto m = RNN(2, 3);
torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
auto rnn_output = m->forward_with_packed_input(packed_input);
auto expected_output = torch::tensor(
{{-0.0645, -0.7274, 0.4531},
{-0.3970, -0.6950, 0.6009},
{-0.3877, -0.7310, 0.6806}});
ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
// Test passing optional argument to `RNN::forward_with_packed_input`
rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
}
{
torch::manual_seed(0);
auto m = LSTM(2, 3);
torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
auto rnn_output = m->forward_with_packed_input(packed_input);
auto expected_output = torch::tensor(
{{-0.2693, -0.1240, 0.0744},
{-0.3889, -0.1919, 0.1183},
{-0.4425, -0.2314, 0.1386}});
ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
// Test passing optional argument to `LSTM::forward_with_packed_input`
rnn_output = m->forward_with_packed_input(packed_input, torch::nullopt);
ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
}
{
torch::manual_seed(0);
auto m = GRU(2, 3);
torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
auto rnn_output = m->forward_with_packed_input(packed_input);
auto expected_output = torch::tensor(
{{-0.1134, 0.0467, 0.2336},
{-0.1189, 0.0502, 0.2960},
{-0.1138, 0.0484, 0.3110}});
ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
// Test passing optional argument to `GRU::forward_with_packed_input`
rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
}
}


@@ -410,7 +410,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
" (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
" (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
" (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
" (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");
Sequential sequential_named({
@@ -429,7 +429,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
" (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
" (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
" (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
" (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");
}
@@ -598,21 +598,21 @@ TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
torch::manual_seed(0);
Sequential sequential(Identity(), RNN(2, 3));
auto x = torch::ones({2, 3, 2});
auto rnn_output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
auto rnn_output = sequential->forward<RNNOutput>(x);
auto expected_output = torch::tensor(
{{{-0.0645, -0.7274, 0.4531},
{-0.0645, -0.7274, 0.4531},
{-0.0645, -0.7274, 0.4531}},
{{-0.3970, -0.6950, 0.6009},
{-0.3970, -0.6950, 0.6009},
{-0.3970, -0.6950, 0.6009}}});
ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
{{{0.0000, 0.0000, 0.4886},
{0.0000, 0.0000, 0.4886},
{0.0000, 0.0000, 0.4886}},
{{0.0000, 0.0000, 0.3723},
{0.0000, 0.0000, 0.3723},
{0.0000, 0.0000, 0.3723}}});
ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
}
{
torch::manual_seed(0);
Sequential sequential(Identity(), LSTM(2, 3));
auto x = torch::ones({2, 3, 2});
auto rnn_output = sequential->forward<std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>>(x);
auto rnn_output = sequential->forward<RNNOutput>(x);
auto expected_output = torch::tensor(
{{{-0.2693, -0.1240, 0.0744},
{-0.2693, -0.1240, 0.0744},
@@ -620,13 +620,13 @@ TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
{{-0.3889, -0.1919, 0.1183},
{-0.3889, -0.1919, 0.1183},
{-0.3889, -0.1919, 0.1183}}});
ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
}
{
torch::manual_seed(0);
Sequential sequential(Identity(), GRU(2, 3));
auto x = torch::ones({2, 3, 2});
auto rnn_output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
auto rnn_output = sequential->forward<RNNOutput>(x);
auto expected_output = torch::tensor(
{{{-0.1134, 0.0467, 0.2336},
{-0.1134, 0.0467, 0.2336},
@@ -634,7 +634,7 @@ TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
{{-0.1189, 0.0502, 0.2960},
{-0.1189, 0.0502, 0.2960},
{-0.1189, 0.0502, 0.2960}}});
ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
}
{
torch::manual_seed(0);


@@ -81,9 +81,9 @@ torch::nn::InstanceNorm3d|Yes|No
torch::nn::LayerNorm|Yes|No
torch::nn::LocalResponseNorm|Yes|No
torch::nn::CrossMapLRN2d|Yes|No
torch::nn::RNN|Yes|No
torch::nn::LSTM|Yes|No
torch::nn::GRU|Yes|No
torch::nn::RNN|No|No
torch::nn::LSTM|No|No
torch::nn::GRU|No|No
torch::nn::RNNCell|Yes|No
torch::nn::LSTMCell|Yes|No
torch::nn::GRUCell|Yes|No


@@ -122,10 +122,6 @@ TORCH_ENUM_DECLARE(BatchMean)
TORCH_ENUM_DECLARE(Zeros)
TORCH_ENUM_DECLARE(Border)
TORCH_ENUM_DECLARE(Reflection)
TORCH_ENUM_DECLARE(RNN_TANH)
TORCH_ENUM_DECLARE(RNN_RELU)
TORCH_ENUM_DECLARE(LSTM)
TORCH_ENUM_DECLARE(GRU)
namespace torch {
namespace enumtype {
@@ -161,10 +157,6 @@ struct _compute_enum_name {
TORCH_ENUM_PRETTY_PRINT(Zeros)
TORCH_ENUM_PRETTY_PRINT(Border)
TORCH_ENUM_PRETTY_PRINT(Reflection)
TORCH_ENUM_PRETTY_PRINT(RNN_TANH)
TORCH_ENUM_PRETTY_PRINT(RNN_RELU)
TORCH_ENUM_PRETTY_PRINT(LSTM)
TORCH_ENUM_PRETTY_PRINT(GRU)
};
template <typename V>


@@ -5,7 +5,6 @@
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/pimpl.h>
#include <torch/nn/utils/rnn.h>
#include <torch/types.h>
#include <ATen/ATen.h>
@@ -16,23 +15,36 @@
#include <memory>
#include <vector>
using namespace torch::nn::utils::rnn;
namespace torch {
namespace nn {
/// The output of a single invocation of an RNN module's `forward()` method.
struct TORCH_API RNNOutput {
/// The result of applying the specific RNN algorithm
/// to the input tensor and input state.
Tensor output;
/// The new, updated state that can be fed into the RNN
/// in the next forward step.
Tensor state;
};
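A minimal usage sketch for this struct (illustrative; the module construction mirrors the tests in this commit):

GRU gru(GRUOptions(16, 32).layers(2));
auto input = torch::randn({5, 8, 16});       // (sequence, batch, features)
RNNOutput out = gru->forward(input);
auto next = gru->forward(input, out.state);  // feed the state back in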
namespace detail {
/// Base class for all RNN implementations (intended for code sharing).
template <typename Derived>
class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
public:
explicit RNNImplBase(const RNNOptionsBase& options_);
/// These must line up with the CUDNN mode codes:
/// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
explicit RNNImplBase(
const RNNOptionsBase& options_,
optional<CuDNNMode> cudnn_mode = nullopt,
int64_t number_of_gates = 1);
/// Initializes the parameters of the RNN module.
void reset() override;
void reset_parameters();
/// Overrides `nn::Module::to()` to call `flatten_parameters()` after the
/// original operation.
void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
@@ -53,32 +65,52 @@ class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
/// called once upon construction, inside `reset()`.
void flatten_parameters();
std::vector<Tensor> all_weights() const;
/// The RNN's options.
RNNOptionsBase options_base;
RNNOptionsBase options;
/// The weights for `input x hidden` gates.
std::vector<Tensor> w_ih;
/// The weights for `hidden x hidden` gates.
std::vector<Tensor> w_hh;
/// The biases for `input x hidden` gates.
std::vector<Tensor> b_ih;
/// The biases for `hidden x hidden` gates.
std::vector<Tensor> b_hh;
protected:
// Resets flat_weights_
// Note: be v. careful before removing this, as 3rd party device types
// likely rely on this behavior to properly .to() modules like LSTM.
void reset_flat_weights();
/// The function signature of `rnn_relu`, `rnn_tanh` and `gru`.
using RNNFunctionSignature = std::tuple<Tensor, Tensor>(
/*input=*/const Tensor&,
/*state=*/const Tensor&,
/*params=*/TensorList,
/*has_biases=*/bool,
/*layers=*/int64_t,
/*dropout=*/double,
/*train=*/bool,
/*bidirectional=*/bool,
/*batch_first=*/bool);
void check_input(const Tensor& input, const Tensor& batch_sizes) const;
/// A generic `forward()` used for RNN and GRU (but not LSTM!). Takes the ATen
/// RNN function as first argument.
RNNOutput generic_forward(
std::function<RNNFunctionSignature> function,
const Tensor& input,
Tensor state);
std::tuple<int64_t, int64_t, int64_t> get_expected_hidden_size(const Tensor& input, const Tensor& batch_sizes) const;
/// Returns a flat vector of all weights, with layer weights following each
/// other sequentially in (w_ih, w_hh, b_ih, b_hh) order.
std::vector<Tensor> flat_weights() const;
void check_hidden_size(
const Tensor& hx,
std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
std::string msg = "Expected hidden size {1}, got {2}") const;
/// Very simple check if any of the parameters (weights, biases) are the same.
bool any_parameters_alias() const;
void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const;
/// The number of gate weights/biases required by the RNN subclass.
int64_t number_of_gates_;
Tensor permute_hidden(Tensor hx, const Tensor& permutation) const;
/// The cuDNN RNN mode, if this RNN subclass has any.
optional<CuDNNMode> cudnn_mode_;
std::vector<std::string> flat_weights_names_;
std::vector<std::vector<std::string>> all_weights_;
/// The cached result of the latest `flat_weights()` call.
std::vector<Tensor> flat_weights_;
};
} // namespace detail
@@ -86,139 +118,84 @@ class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A multi-layer Elman RNN module with Tanh or ReLU activation.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::RNNOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// RNN model(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
/// ```
/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN to learn about the
/// exact behavior of this module.
class TORCH_API RNNImpl : public detail::RNNImplBase<RNNImpl> {
public:
RNNImpl(int64_t input_size, int64_t hidden_size)
: RNNImpl(RNNOptions(input_size, hidden_size)) {}
explicit RNNImpl(const RNNOptions& options_);
std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});
/// Pretty prints the `RNN` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
/// Applies the `RNN` module to an input sequence and input state.
/// The `input` should follow a `(sequence, batch, features)` layout unless
/// `batch_first` is true, in which case the layout should be `(batch,
/// sequence, features)`.
RNNOutput forward(const Tensor& input, Tensor state = {});
protected:
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
public:
std::tuple<PackedSequence, Tensor> forward_with_packed_input(const PackedSequence& packed_input, Tensor hx = {});
RNNOptions options;
protected:
std::tuple<Tensor, Tensor> forward_helper(
const Tensor& input,
const Tensor& batch_sizes,
const Tensor& sorted_indices,
int64_t max_batch_size,
Tensor hx);
};
/// A `ModuleHolder` subclass for `RNNImpl`.
/// See the documentation for `RNNImpl` class to learn what methods it
/// provides, and examples of how to use `RNN` with `torch::nn::RNNOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
/// See the documentation for `RNNImpl` class to learn what methods it provides,
/// or the documentation for `ModuleHolder` to learn about PyTorch's module
/// storage semantics.
TORCH_MODULE(RNN);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A multi-layer long short-term memory (LSTM) module.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::LSTMOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// LSTM model(LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
/// ```
/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM to learn about the
/// exact behavior of this module.
class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
public:
LSTMImpl(int64_t input_size, int64_t hidden_size)
: LSTMImpl(LSTMOptions(input_size, hidden_size)) {}
explicit LSTMImpl(const LSTMOptions& options_);
std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(
const Tensor& input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
/// Applies the `LSTM` module to an input sequence and input state.
/// The `input` should follow a `(sequence, batch, features)` layout unless
/// `batch_first` is true, in which case the layout should be `(batch,
/// sequence, features)`.
RNNOutput forward(const Tensor& input, Tensor state = {});
protected:
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})
public:
std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> forward_with_packed_input(
const PackedSequence& packed_input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
LSTMOptions options;
protected:
void check_forward_args(const Tensor& input, std::tuple<Tensor, Tensor> hidden, const Tensor& batch_sizes) const;
std::tuple<Tensor, Tensor> permute_hidden(std::tuple<Tensor, Tensor> hx, const Tensor& permutation) const;
std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward_helper(
const Tensor& input,
const Tensor& batch_sizes,
const Tensor& sorted_indices,
int64_t max_batch_size,
torch::optional<std::tuple<Tensor, Tensor>> hx_opt);
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
};
/// A `ModuleHolder` subclass for `LSTMImpl`.
/// See the documentation for `LSTMImpl` class to learn what methods it
/// provides, and examples of how to use `LSTM` with `torch::nn::LSTMOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(LSTM);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A multi-layer gated recurrent unit (GRU) module.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::GRUOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// GRU model(GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
/// ```
/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn about the
/// exact behavior of this module.
class TORCH_API GRUImpl : public detail::RNNImplBase<GRUImpl> {
public:
GRUImpl(int64_t input_size, int64_t hidden_size)
: GRUImpl(GRUOptions(input_size, hidden_size)) {}
explicit GRUImpl(const GRUOptions& options_);
std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});
/// Applies the `GRU` module to an input sequence and input state.
/// The `input` should follow a `(sequence, batch, features)` layout unless
/// `batch_first` is true, in which case the layout should be `(batch,
/// sequence, features)`.
RNNOutput forward(const Tensor& input, Tensor state = {});
protected:
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::Tensor())})
public:
std::tuple<PackedSequence, Tensor> forward_with_packed_input(const PackedSequence& packed_input, Tensor hx = {});
GRUOptions options;
protected:
std::tuple<Tensor, Tensor> forward_helper(
const Tensor& input,
const Tensor& batch_sizes,
const Tensor& sorted_indices,
int64_t max_batch_size,
Tensor hx);
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
};
/// A `ModuleHolder` subclass for `GRUImpl`.
/// See the documentation for `GRUImpl` class to learn what methods it
/// provides, and examples of how to use `GRU` with `torch::nn::GRUOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
/// See the documentation for `GRUImpl` class to learn what methods it provides,
/// or the documentation for `ModuleHolder` to learn about PyTorch's module
/// storage semantics.
TORCH_MODULE(GRU);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@@ -10,137 +10,65 @@ namespace nn {
namespace detail {
/// Common options for RNN, LSTM and GRU modules.
/// Common options for LSTM and GRU modules.
struct TORCH_API RNNOptionsBase {
typedef c10::variant<
enumtype::kLSTM,
enumtype::kGRU,
enumtype::kRNN_TANH,
enumtype::kRNN_RELU> rnn_options_base_mode_t;
RNNOptionsBase(rnn_options_base_mode_t mode, int64_t input_size, int64_t hidden_size);
TORCH_ARG(rnn_options_base_mode_t, mode);
RNNOptionsBase(int64_t input_size, int64_t hidden_size);
virtual ~RNNOptionsBase() = default;
/// The number of features of a single sample in the input sequence `x`.
TORCH_ARG(int64_t, input_size);
/// The number of features in the hidden state `h`.
TORCH_ARG(int64_t, hidden_size);
/// The number of recurrent layers (cells) to use.
TORCH_ARG(int64_t, num_layers) = 1;
TORCH_ARG(int64_t, layers) = 1;
/// Whether a bias term should be added to all linear operations.
TORCH_ARG(bool, bias) = true;
/// If true, the input sequence should be provided as `(batch, sequence,
/// features)`. If false (default), the expected layout is `(sequence, batch,
/// features)`.
TORCH_ARG(bool, batch_first) = false;
TORCH_ARG(bool, with_bias) = true;
/// If non-zero, adds dropout with the given probability to the output of each
/// RNN layer, except the final layer.
TORCH_ARG(double, dropout) = 0.0;
/// Whether to make the RNN bidirectional.
TORCH_ARG(bool, bidirectional) = false;
/// If true, the input sequence should be provided as `(batch, sequence,
/// features)`. If false (default), the expected layout is `(sequence, batch,
/// features)`.
TORCH_ARG(bool, batch_first) = false;
};
} // namespace detail
/// Options for the `RNN` module.
///
/// Example:
/// ```
/// RNN model(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
/// ```
struct TORCH_API RNNOptions {
typedef c10::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
enum class RNNActivation : uint32_t {ReLU, Tanh};
/// Options for RNN modules.
struct TORCH_API RNNOptions {
RNNOptions(int64_t input_size, int64_t hidden_size);
/// The number of expected features in the input `x`
/// Sets the activation after linear operations to `tanh`.
RNNOptions& tanh();
/// Sets the activation after linear operations to `relu`.
RNNOptions& relu();
/// The number of features of a single sample in the input sequence `x`.
TORCH_ARG(int64_t, input_size);
/// The number of features in the hidden state `h`
/// The number of features in the hidden state `h`.
TORCH_ARG(int64_t, hidden_size);
/// Number of recurrent layers. E.g., setting ``num_layers=2``
/// would mean stacking two RNNs together to form a `stacked RNN`,
/// with the second RNN taking in outputs of the first RNN and
/// computing the final results. Default: 1
TORCH_ARG(int64_t, num_layers) = 1;
/// The non-linearity to use. Can be either ``torch::kTanh`` or ``torch::kReLU``. Default: ``torch::kTanh``
TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh;
/// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
/// Default: ``true``
TORCH_ARG(bool, bias) = true;
/// If ``true``, then the input and output tensors are provided
/// as `(batch, seq, feature)`. Default: ``false``
TORCH_ARG(bool, batch_first) = false;
/// If non-zero, introduces a `Dropout` layer on the outputs of each
/// RNN layer except the last layer, with dropout probability equal to
/// `dropout`. Default: 0
/// The number of recurrent layers (cells) to use.
TORCH_ARG(int64_t, layers) = 1;
/// Whether a bias term should be added to all linear operations.
TORCH_ARG(bool, with_bias) = true;
/// If non-zero, adds dropout with the given probability to the output of each
/// RNN layer, except the final layer.
TORCH_ARG(double, dropout) = 0.0;
/// If ``true``, becomes a bidirectional RNN. Default: ``false``
/// Whether to make the RNN bidirectional.
TORCH_ARG(bool, bidirectional) = false;
/// If true, the input sequence should be provided as `(batch, sequence,
/// features)`. If false (default), the expected layout is `(sequence, batch,
/// features)`.
TORCH_ARG(bool, batch_first) = false;
/// The activation to use after linear operations.
TORCH_ARG(RNNActivation, activation) = RNNActivation::ReLU;
};
/// Options for the `LSTM` module.
///
/// Example:
/// ```
/// LSTM model(LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
/// ```
struct TORCH_API LSTMOptions {
LSTMOptions(int64_t input_size, int64_t hidden_size);
/// The number of expected features in the input `x`
TORCH_ARG(int64_t, input_size);
/// The number of features in the hidden state `h`
TORCH_ARG(int64_t, hidden_size);
/// Number of recurrent layers. E.g., setting ``num_layers=2``
/// would mean stacking two LSTMs together to form a `stacked LSTM`,
/// with the second LSTM taking in outputs of the first LSTM and
/// computing the final results. Default: 1
TORCH_ARG(int64_t, num_layers) = 1;
/// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
/// Default: ``true``
TORCH_ARG(bool, bias) = true;
/// If ``true``, then the input and output tensors are provided
/// as (batch, seq, feature). Default: ``false``
TORCH_ARG(bool, batch_first) = false;
/// If non-zero, introduces a `Dropout` layer on the outputs of each
/// LSTM layer except the last layer, with dropout probability equal to
/// `dropout`. Default: 0
TORCH_ARG(double, dropout) = 0.0;
/// If ``true``, becomes a bidirectional LSTM. Default: ``false``
TORCH_ARG(bool, bidirectional) = false;
};
/// Options for the `GRU` module.
///
/// Example:
/// ```
/// GRU model(GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
/// ```
struct TORCH_API GRUOptions {
GRUOptions(int64_t input_size, int64_t hidden_size);
/// The number of expected features in the input `x`
TORCH_ARG(int64_t, input_size);
/// The number of features in the hidden state `h`
TORCH_ARG(int64_t, hidden_size);
/// Number of recurrent layers. E.g., setting ``num_layers=2``
/// would mean stacking two GRUs together to form a `stacked GRU`,
/// with the second GRU taking in outputs of the first GRU and
/// computing the final results. Default: 1
TORCH_ARG(int64_t, num_layers) = 1;
/// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
/// Default: ``true``
TORCH_ARG(bool, bias) = true;
/// If ``true``, then the input and output tensors are provided
/// as (batch, seq, feature). Default: ``false``
TORCH_ARG(bool, batch_first) = false;
/// If non-zero, introduces a `Dropout` layer on the outputs of each
/// GRU layer except the last layer, with dropout probability equal to
/// `dropout`. Default: 0
TORCH_ARG(double, dropout) = 0.0;
/// If ``true``, becomes a bidirectional GRU. Default: ``false``
TORCH_ARG(bool, bidirectional) = false;
};
using LSTMOptions = detail::RNNOptionsBase;
using GRUOptions = detail::RNNOptionsBase;
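With these aliases, LSTM and GRU share the base options verbatim; only `RNNOptions` adds an activation setter. A short sketch of the restored builder-style configuration (values are illustrative):

auto lstm_opt = LSTMOptions(128, 64).layers(3).dropout(0.2).bidirectional(true);
auto rnn_opt = RNNOptions(128, 64).layers(2).tanh();  // or .relu()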
namespace detail {


@@ -30,7 +30,3 @@ TORCH_ENUM_DEFINE(BatchMean)
TORCH_ENUM_DEFINE(Zeros)
TORCH_ENUM_DEFINE(Border)
TORCH_ENUM_DEFINE(Reflection)
TORCH_ENUM_DEFINE(RNN_TANH)
TORCH_ENUM_DEFINE(RNN_RELU)
TORCH_ENUM_DEFINE(LSTM)
TORCH_ENUM_DEFINE(GRU)


@@ -1,5 +1,6 @@
#include <torch/nn/modules/rnn.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/init.h>
#include <torch/types.h>
#include <torch/utils.h>
@@ -18,193 +19,65 @@
#include <utility>
#include <vector>
using namespace torch::nn::utils::rnn;
namespace torch {
namespace nn {
/// These must line up with the CUDNN mode codes:
/// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
CuDNNMode get_cudnn_mode_for_rnn(detail::RNNOptionsBase::rnn_options_base_mode_t mode) {
if (c10::get_if<enumtype::kRNN_RELU>(&mode)) {
return CuDNNMode::RNN_RELU;
} else if (c10::get_if<enumtype::kRNN_TANH>(&mode)) {
return CuDNNMode::RNN_TANH;
} else if (c10::get_if<enumtype::kLSTM>(&mode)) {
return CuDNNMode::LSTM;
} else if (c10::get_if<enumtype::kGRU>(&mode)) {
return CuDNNMode::GRU;
} else {
TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(mode));
}
}
Tensor apply_permutation(const Tensor& tensor, const Tensor& permutation, int64_t dim = 1) {
return tensor.index_select(dim, permutation);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
namespace detail {
template <typename Derived>
RNNImplBase<Derived>::RNNImplBase(const RNNOptionsBase& options_)
: options_base(options_) {
RNNImplBase<Derived>::RNNImplBase(
const RNNOptionsBase& options_,
optional<CuDNNMode> cudnn_mode,
int64_t number_of_gates)
: options(options_),
number_of_gates_(number_of_gates),
cudnn_mode_(std::move(cudnn_mode)) {
reset();
}
template <typename Derived>
void RNNImplBase<Derived>::reset() {
const int64_t num_directions = options_base.bidirectional() ? 2 : 1;
const auto num_directions = options.bidirectional() ? 2 : 1;
TORCH_CHECK(
0 <= options_base.dropout() && options_base.dropout() <= 1,
"dropout should be a number in range [0, 1] ",
"representing the probability of an element being ",
"zeroed");
w_ih.resize(options.layers() * num_directions);
w_hh.resize(options.layers() * num_directions);
b_ih.resize(options.layers() * num_directions);
b_hh.resize(options.layers() * num_directions);
if (options_base.dropout() > 0 && options_base.num_layers() == 1) {
TORCH_WARN(
"dropout option adds dropout after all but last ",
"recurrent layer, so non-zero dropout expects ",
"num_layers greater than 1, but got dropout=", options_base.dropout(), " and ",
"num_layers=", options_base.num_layers());
}
const int64_t gate_size = options.hidden_size() * number_of_gates_;
int64_t gate_size = 0;
if (c10::get_if<enumtype::kLSTM>(&options_base.mode())) {
gate_size = 4 * options_base.hidden_size();
} else if (c10::get_if<enumtype::kGRU>(&options_base.mode())) {
gate_size = 3 * options_base.hidden_size();
} else if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
gate_size = options_base.hidden_size();
} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
gate_size = options_base.hidden_size();
} else {
TORCH_CHECK(false, "Unrecognized RNN mode: " + torch::enumtype::get_enum_name(options_base.mode()));
}
for (int64_t layer = 0; layer < options.layers(); ++layer) {
for (auto direction = 0; direction < num_directions; direction++) {
const auto layer_input_size = layer == 0 ? options.input_size() :
options.hidden_size() * num_directions;
const auto suffix = direction == 1 ? "_reverse" : "";
const auto layer_idx = (layer * num_directions) + direction;
w_ih[layer_idx] = this->register_parameter(
"weight_ih_l" + std::to_string(layer) + suffix,
torch::empty({gate_size, layer_input_size}));
w_hh[layer_idx] = this->register_parameter(
"weight_hh_l" + std::to_string(layer) + suffix,
torch::empty({gate_size, options.hidden_size()}));
flat_weights_names_ = {};
all_weights_ = {};
for (int64_t layer = 0; layer < options_base.num_layers(); layer++) {
for (int64_t direction = 0; direction < num_directions; direction++) {
int64_t layer_input_size = layer == 0 ? options_base.input_size() : options_base.hidden_size() * num_directions;
auto w_ih = torch::empty({gate_size, layer_input_size});
auto w_hh = torch::empty({gate_size, options_base.hidden_size()});
auto b_ih = torch::empty({gate_size});
// Second bias vector included for CuDNN compatibility. Only one
// bias vector is needed in standard definition.
auto b_hh = torch::empty({gate_size});
std::vector<Tensor> layer_params = {w_ih, w_hh, b_ih, b_hh};
std::string suffix = direction == 1 ? "_reverse" : "";
std::vector<std::string> param_names = {"weight_ih_l{layer}{suffix}", "weight_hh_l{layer}{suffix}"};
if (options_base.bias()) {
param_names.emplace_back("bias_ih_l{layer}{suffix}");
param_names.emplace_back("bias_hh_l{layer}{suffix}");
if (options.with_bias()) {
b_ih[layer_idx] = this->register_parameter(
"bias_ih_l" + std::to_string(layer) + suffix,
torch::empty({gate_size}));
b_hh[layer_idx] = this->register_parameter(
"bias_hh_l" + std::to_string(layer) + suffix,
torch::empty({gate_size}));
}
for (size_t i = 0; i < param_names.size(); i++) { // NOLINT(modernize-loop-convert)
std::string x = std::regex_replace(param_names[i], std::regex("\\{layer\\}"), c10::str(layer));
x = std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix));
param_names[i] = x;
}
for (size_t i = 0; i < param_names.size(); i++) {
auto name = param_names[i];
auto param = layer_params[i];
this->register_parameter(name, param);
}
flat_weights_names_.insert(flat_weights_names_.end(), param_names.begin(), param_names.end());
all_weights_.emplace_back(param_names);
}
}
flat_weights_ = {};
for (const auto& wn : flat_weights_names_) {
auto named_parameters = this->named_parameters(/*recurse=*/false);
if (named_parameters.contains(wn)) {
flat_weights_.emplace_back(named_parameters[wn]);
} else {
flat_weights_.emplace_back(Tensor());
}
}
this->flatten_parameters();
this->reset_parameters();
}
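The parameter shapes registered above follow from `gate_size = hidden_size * number_of_gates_`, with 1 gate for a plain RNN, 3 for GRU, and 4 for LSTM. A worked example with illustrative numbers:

const int64_t hidden_size = 64;
const int64_t gate_size = hidden_size * 4;  // LSTM: 4 gates -> 256
// => w_ih[0]: {256, input_size}, w_hh[0]: {256, 64},
//    b_ih[0] and b_hh[0]: {256}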
template <typename Derived>
void RNNImplBase<Derived>::flatten_parameters() {
// Resets parameter data pointer so that they can use faster code paths.
//
// Right now, this works only if the module is on the GPU and cuDNN is enabled.
// Otherwise, it's a no-op.
// Short-circuits if flat_weights_ is only partially instantiated
if (flat_weights_.size() != flat_weights_names_.size()) {
return;
}
// Short-circuits if any tensor in self.flat_weights_ is not acceptable to cuDNN
// or the tensors in flat_weights_ are of different dtypes
auto first_fw = flat_weights_[0];
auto dtype = first_fw.dtype();
for (const auto& fw : flat_weights_) {
if (!(fw.dtype() == dtype) ||
!fw.is_cuda() ||
!torch::cudnn_is_acceptable(fw)) {
return;
}
}
// If any parameters alias, we fall back to the slower, copying code path. This is
// a sufficient check, because overlapping parameter buffers that don't completely
// alias would break the assumptions of the uniqueness check in
// Module::named_parameters().
std::unordered_set<void*> unique_data_ptrs;
for (const auto& p : flat_weights_) {
unique_data_ptrs.emplace(p.data_ptr());
}
if (unique_data_ptrs.size() != flat_weights_.size()) {
return;
}
{
torch::DeviceGuard device_guard(first_fw.device());
// Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
// an inplace operation on self.flat_weights_
{
torch::NoGradGuard no_grad;
if (torch::_use_cudnn_rnn_flatten_weight()) {
torch::_cudnn_rnn_flatten_weight(
flat_weights_,
options_base.bias() ? 4 : 2,
options_base.input_size(),
static_cast<int64_t>(get_cudnn_mode_for_rnn(options_base.mode())),
options_base.hidden_size(),
options_base.num_layers(),
options_base.batch_first(),
options_base.bidirectional());
}
NoGradGuard no_grad;
const auto stdv = 1.0 / std::sqrt(options.hidden_size());
for (auto& p : this->parameters()) {
p.uniform_(-stdv, stdv);
}
}
}
template <typename Derived>
void RNNImplBase<Derived>::reset_flat_weights() {
flat_weights_ = {};
for (const auto& wn : flat_weights_names_) {
auto named_parameters = this->named_parameters(/*recurse=*/false);
if (named_parameters.contains(wn)) {
flat_weights_.emplace_back(named_parameters[wn]);
} else {
flat_weights_.emplace_back(Tensor());
}
}
flatten_parameters();
}
template <typename Derived>
@@ -213,113 +86,126 @@ void RNNImplBase<Derived>::to(
torch::Dtype dtype,
bool non_blocking) {
nn::Module::to(device, dtype, non_blocking);
reset_flat_weights();
flatten_parameters();
}
template <typename Derived>
void RNNImplBase<Derived>::to(torch::Dtype dtype, bool non_blocking) {
nn::Module::to(dtype, non_blocking);
reset_flat_weights();
flatten_parameters();
}
template <typename Derived>
void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
nn::Module::to(device, non_blocking);
reset_flat_weights();
const auto num_directions = options.bidirectional() ? 2 : 1;
for (int64_t layer = 0; layer < options.layers(); layer++) {
for (auto direction = 0; direction < num_directions; direction++) {
const auto layer_idx = (layer * num_directions) + direction;
w_ih[layer_idx] = w_ih[layer_idx].to(device, non_blocking);
w_hh[layer_idx] = w_hh[layer_idx].to(device, non_blocking);
if (options.with_bias()) {
b_ih[layer_idx] = b_ih[layer_idx].to(device, non_blocking);
b_hh[layer_idx] = b_hh[layer_idx].to(device, non_blocking);
}
}
}
flatten_parameters();
}
template <typename Derived>
void RNNImplBase<Derived>::reset_parameters() {
const double stdv = 1.0 / std::sqrt(options_base.hidden_size());
for (auto& weight : this->parameters()) {
init::uniform_(weight, -stdv, stdv);
}
}
template <typename Derived>
void RNNImplBase<Derived>::check_input(const Tensor& input, const Tensor& batch_sizes) const {
int64_t expected_input_dim = batch_sizes.defined() ? 2 : 3;
TORCH_CHECK(
input.dim() == expected_input_dim,
"input must have ", expected_input_dim, " dimensions, got ", input.dim());
TORCH_CHECK(
options_base.input_size() == input.size(-1),
"input.size(-1) must be equal to input_size. Expected ", options_base.input_size(), ", got ", input.size(-1));
}
template <typename Derived>
std::tuple<int64_t, int64_t, int64_t> RNNImplBase<Derived>::get_expected_hidden_size(
const Tensor& input, const Tensor& batch_sizes) const {
int64_t mini_batch = 0;
if (batch_sizes.defined()) {
mini_batch = batch_sizes[0].item<int64_t>();
} else {
mini_batch = options_base.batch_first() ? input.size(0) : input.size(1);
}
int64_t num_directions = options_base.bidirectional() ? 2 : 1;
return std::make_tuple(options_base.num_layers() * num_directions, mini_batch, options_base.hidden_size());
}
template <typename Derived>
void RNNImplBase<Derived>::check_hidden_size(
const Tensor& hx,
std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
std::string msg) const {
auto expected_hidden_size_vec = std::vector<int64_t>({
std::get<0>(expected_hidden_size),
std::get<1>(expected_hidden_size),
std::get<2>(expected_hidden_size),
});
if (hx.sizes() != expected_hidden_size_vec) {
msg = std::regex_replace(msg, std::regex("\\{1\\}"), c10::str(expected_hidden_size_vec));
msg = std::regex_replace(msg, std::regex("\\{2\\}"), c10::str(hx.sizes()));
TORCH_CHECK(false, msg);
}
}
template <typename Derived>
void RNNImplBase<Derived>::check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const {
this->check_input(input, batch_sizes);
auto expected_hidden_size = this->get_expected_hidden_size(input, batch_sizes);
this->check_hidden_size(hidden, expected_hidden_size);
}
template <typename Derived>
Tensor RNNImplBase<Derived>::permute_hidden(Tensor hx, const Tensor& permutation) const {
if (!permutation.defined()) {
return hx;
}
return apply_permutation(hx, permutation);
}
template <typename Derived>
void RNNImplBase<Derived>::pretty_print(std::ostream& stream) const {
const std::string name = this->name();
const std::string name_without_impl = name.substr(0, name.size() - 4);
stream << name_without_impl << "(input_size=" << options.input_size()
<< ", hidden_size=" << options.hidden_size()
<< ", layers=" << options.layers() << ", dropout=" << options.dropout()
<< ")";
}
template <typename Derived>
std::vector<Tensor> RNNImplBase<Derived>::all_weights() const {
  std::vector<Tensor> result = {};
  auto named_parameters = this->named_parameters(/*recurse=*/false);
  for (const auto& weights : all_weights_) {
    for (const auto& weight : weights) {
      result.emplace_back(named_parameters[weight]);
    }
  }
  return result;
}
template <typename Derived>
void RNNImplBase<Derived>::flatten_parameters() {
// Cache the flattened weight and bias vector.
flat_weights_ = flat_weights();
if (!cudnn_mode_ || !torch::cudnn_is_acceptable(w_ih.at(0))) {
return;
}
NoGradGuard no_grad;
if (torch::_use_cudnn_rnn_flatten_weight()) {
torch::_cudnn_rnn_flatten_weight(
flat_weights_,
/*weight_stride0=*/options.with_bias() ? 4 : 2,
options.input_size(),
static_cast<int64_t>(*cudnn_mode_),
options.hidden_size(),
options.layers(),
/*batch_first=*/options.batch_first(),
/*bidirectional=*/options.bidirectional());
}
}
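// Note (our understanding, not from the original sources): for CPU tensors
// torch::cudnn_is_acceptable() is false, so this reduces to refreshing the
// flat_weights_ cache; the fused _cudnn_rnn_flatten_weight call only runs
// for cuDNN-eligible CUDA inputs.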
template <typename Derived>
RNNOutput RNNImplBase<Derived>::generic_forward(
std::function<RNNFunctionSignature> function,
const Tensor& input,
Tensor state) {
if (!state.defined()) {
// #layers, batch size, state size
const auto batch_size = input.size(options.batch_first() ? 0 : 1);
const auto num_directions = options.bidirectional() ? 2 : 1;
state = torch::zeros(
{options.layers() * num_directions, batch_size, options.hidden_size()},
input.options());
}
Tensor output, new_state;
std::tie(output, new_state) = function(
input,
std::move(state),
flat_weights_,
options.with_bias(),
options.layers(),
options.dropout(),
this->is_training(),
options.bidirectional(),
options.batch_first());
return {output, new_state};
}
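// Usage sketch (mirroring GRUImpl::forward further down): subclasses whose
// ATen kernel returns an (output, state) pair dispatch through this helper,
// e.g.
//   return generic_forward(
//       static_cast<RNNFunctionSignature*>(&torch::gru), input, std::move(state));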
template <typename Derived>
std::vector<Tensor> RNNImplBase<Derived>::flat_weights() const {
// Organize all weights in a flat vector in the order
// (w_ih, w_hh, b_ih, b_hh), repeated for each layer (next to each other).
std::vector<Tensor> flat;
const auto num_directions = options.bidirectional() ? 2 : 1;
for (int64_t layer = 0; layer < options.layers(); layer++) {
for (auto direction = 0; direction < num_directions; direction++) {
const auto layer_idx = (layer * num_directions) + direction;
flat.push_back(w_ih[layer_idx]);
flat.push_back(w_hh[layer_idx]);
if (options.with_bias()) {
flat.push_back(b_ih[layer_idx]);
flat.push_back(b_hh[layer_idx]);
}
}
}
  return flat;
}
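// Layout example (illustrative): a 2-layer unidirectional module with biases
// yields
//   {w_ih[0], w_hh[0], b_ih[0], b_hh[0], w_ih[1], w_hh[1], b_ih[1], b_hh[1]}
// i.e. a weight_stride0 of 4, matching the value flatten_parameters() passes
// to _cudnn_rnn_flatten_weight.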
template <typename Derived>
bool RNNImplBase<Derived>::any_parameters_alias() const {
// If any parameters alias, we fall back to the slower, copying code path.
// This is a sufficient check, because overlapping parameter buffers that
// don't completely alias would break the assumptions of the uniqueness check
// in Module::named_parameters().
std::unordered_set<void*> unique_data_ptrs;
auto params = this->parameters();
unique_data_ptrs.reserve(params.size());
for (const auto& p : params) {
unique_data_ptrs.emplace(p.data_ptr());
}
return unique_data_ptrs.size() != params.size();
}
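// Hypothetical trigger (names ours): after something like
//   model->w_ih[0] = model->w_hh[0];  // two parameters now share one data_ptr
// the set ends up smaller than params.size(), and callers fall back to the
// slower, copying code path.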
template class RNNImplBase<LSTMImpl>;
@ -329,275 +215,91 @@ template class RNNImplBase<RNNImpl>;
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
detail::RNNOptionsBase::rnn_options_base_mode_t compute_rnn_options_base_mode(
RNNOptions::nonlinearity_t nonlinearity) {
if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
return torch::kRNN_TANH;
} else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {
return torch::kRNN_RELU;
} else {
TORCH_CHECK(false, "Unknown nonlinearity ", torch::enumtype::get_enum_name(nonlinearity));
}
}
RNNImpl::RNNImpl(const RNNOptions& options_)
: detail::RNNImplBase<RNNImpl>(
          detail::RNNOptionsBase(options_.input_size(), options_.hidden_size())
              .layers(options_.layers())
              .with_bias(options_.with_bias())
              .dropout(options_.dropout())
              .bidirectional(options_.bidirectional())
              .batch_first(options_.batch_first()),
          static_cast<CuDNNMode>(options_.activation())),
      options(options_) {}
std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
const Tensor& input,
const Tensor& batch_sizes,
const Tensor& sorted_indices,
int64_t max_batch_size,
Tensor hx) {
if (!hx.defined()) {
int64_t num_directions = options_base.bidirectional() ? 2 : 1;
hx = torch::zeros({options_base.num_layers() * num_directions,
max_batch_size, options_base.hidden_size()},
torch::dtype(input.dtype()).device(input.device()));
} else {
    // Each batch of the hidden state should match the input sequence that
    // the user believes they are passing in.
hx = this->permute_hidden(hx, sorted_indices);
}
this->check_forward_args(input, hx, batch_sizes);
std::tuple<Tensor, Tensor> result;
if (!batch_sizes.defined()) {
if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
result = torch::rnn_tanh(input, hx, flat_weights_, options_base.bias(), options_base.num_layers(),
options_base.dropout(), this->is_training(), options_base.bidirectional(), options_base.batch_first());
} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
result = torch::rnn_relu(input, hx, flat_weights_, options_base.bias(), options_base.num_layers(),
options_base.dropout(), this->is_training(), options_base.bidirectional(), options_base.batch_first());
} else {
TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(options_base.mode()));
}
} else {
if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
result = torch::rnn_tanh(input, batch_sizes, hx, flat_weights_, options_base.bias(),
options_base.num_layers(), options_base.dropout(), this->is_training(), options_base.bidirectional());
} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
result = torch::rnn_relu(input, batch_sizes, hx, flat_weights_, options_base.bias(),
options_base.num_layers(), options_base.dropout(), this->is_training(), options_base.bidirectional());
} else {
TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(options_base.mode()));
}
  }
  auto output = std::get<0>(result);
  auto hidden = std::get<1>(result);
  return std::make_tuple(output, hidden);
}
void RNNImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::RNN(input_size=" << options.input_size()
         << ", hidden_size=" << options.hidden_size()
         << ", layers=" << options.layers() << ", dropout=" << options.dropout()
         << ", activation="
         << (options.activation() == RNNActivation::Tanh ? "tanh" : "relu")
         << ")";
}
RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) {
  switch (options.activation()) {
    case RNNActivation::ReLU:
      return generic_forward(
          static_cast<RNNFunctionSignature*>(&torch::rnn_relu),
          input,
          std::move(state));
    case RNNActivation::Tanh:
      return generic_forward(
          static_cast<RNNFunctionSignature*>(&torch::rnn_tanh),
          input,
          std::move(state));
    default:
      AT_ERROR("Unhandled RNN activation function!");
  }
}
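// Illustrative usage (our sketch, not part of this file): with the restored
// API an RNN is configured and run as
//   RNN rnn(RNNOptions(128, 64).layers(2).relu());
//   auto out = rnn->forward(torch::randn({10, 8, 128}));
//   // out.output is (10, 8, 64); out.state is (2, 8, 64)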
std::tuple<Tensor, Tensor> RNNImpl::forward(const Tensor& input, Tensor hx) {
auto batch_sizes = torch::Tensor();
auto max_batch_size = options_base.batch_first() ? input.size(0) : input.size(1);
auto sorted_indices = torch::Tensor();
auto unsorted_indices = torch::Tensor();
Tensor output, hidden;
std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
}
std::tuple<PackedSequence, Tensor> RNNImpl::forward_with_packed_input(const PackedSequence& packed_input, Tensor hx) {
const auto& input = packed_input.data();
const auto& batch_sizes = packed_input.batch_sizes();
const auto& sorted_indices = packed_input.sorted_indices();
const auto& unsorted_indices = packed_input.unsorted_indices();
auto max_batch_size = batch_sizes[0].item<int64_t>();
Tensor output, hidden;
std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
LSTMImpl::LSTMImpl(const LSTMOptions& options_)
: detail::RNNImplBase<LSTMImpl>(
          options_,
          CuDNNMode::LSTM,
          /*number_of_gates=*/4) {}
void LSTMImpl::check_forward_args(const Tensor& input, std::tuple<Tensor, Tensor> hidden, const Tensor& batch_sizes) const {
this->check_input(input, batch_sizes);
auto expected_hidden_size = this->get_expected_hidden_size(input, batch_sizes);
this->check_hidden_size(std::get<0>(hidden), expected_hidden_size,
"Expected hidden[0] size {1}, got {2}");
this->check_hidden_size(std::get<1>(hidden), expected_hidden_size,
"Expected hidden[1] size {1}, got {2}");
}
std::tuple<Tensor, Tensor> LSTMImpl::permute_hidden(std::tuple<Tensor, Tensor> hx, const Tensor& permutation) const {
  if (!permutation.defined()) {
    return hx;
  }
  return std::make_tuple(
    apply_permutation(std::get<0>(hx), permutation),
    apply_permutation(std::get<1>(hx), permutation)
  );
}
std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward_helper(
    const Tensor& input,
    const Tensor& batch_sizes,
    const Tensor& sorted_indices,
    int64_t max_batch_size,
    torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
  std::tuple<Tensor, Tensor> hx;
  if (!hx_opt.has_value()) {
    int64_t num_directions = options.bidirectional() ? 2 : 1;
    auto zeros = torch::zeros({options.num_layers() * num_directions,
                    max_batch_size, options.hidden_size()},
                    torch::dtype(input.dtype()).device(input.device()));
    hx = std::make_tuple(zeros, zeros);
  } else {
    hx = hx_opt.value();
    // Each batch of the hidden state should match the input sequence that
    // the user believes they are passing in.
    hx = this->permute_hidden(hx, sorted_indices);
  }
  this->check_forward_args(input, hx, batch_sizes);
  std::tuple<Tensor, Tensor, Tensor> result;
  if (!batch_sizes.defined()) {
    result = torch::lstm(input, {std::get<0>(hx), std::get<1>(hx)}, flat_weights_, options.bias(), options.num_layers(),
                         options.dropout(), this->is_training(), options.bidirectional(), options.batch_first());
  } else {
    result = torch::lstm(input, batch_sizes, {std::get<0>(hx), std::get<1>(hx)}, flat_weights_, options.bias(),
                         options.num_layers(), options.dropout(), this->is_training(), options.bidirectional());
  }
  auto output = std::get<0>(result);
  auto hidden = std::make_tuple(std::get<1>(result), std::get<2>(result));
  return std::make_tuple(output, hidden);
}
std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward(
    const Tensor& input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
  auto batch_sizes = torch::Tensor();
  auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
  auto sorted_indices = torch::Tensor();
  auto unsorted_indices = torch::Tensor();
  Tensor output;
  std::tuple<Tensor, Tensor> hidden;
  std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));
  return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
}
std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> LSTMImpl::forward_with_packed_input(
    const PackedSequence& packed_input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
  const auto& input = packed_input.data();
  const auto& batch_sizes = packed_input.batch_sizes();
  const auto& sorted_indices = packed_input.sorted_indices();
  const auto& unsorted_indices = packed_input.unsorted_indices();
  auto max_batch_size = batch_sizes[0].item<int64_t>();
  Tensor output;
  std::tuple<Tensor, Tensor> hidden;
  std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));
  auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
  return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
}
RNNOutput LSTMImpl::forward(const Tensor& input, Tensor state) {
  // It would be trickier to adapt `generic_forward` for the LSTM because its
  // output has a different dimensionality (a 3-tuple vs. a 2-tuple), while we
  // always return a single state variable (the hidden and cell state stacked
  // into one tensor). That also changes the state variable that goes into
  // `generic_forward` and the way we default-initialize the state when it is
  // not passed, so we just re-implement the logic specifically for the LSTM.
  if (!state.defined()) {
    // 2 for hidden state and cell state, then #layers, batch size, state size
    const auto batch_size = input.size(options.batch_first() ? 0 : 1);
    const auto num_directions = options.bidirectional() ? 2 : 1;
    state = torch::zeros(
        {2, options.layers() * num_directions, batch_size, options.hidden_size()},
        input.options());
  }
  Tensor output, hidden_state, cell_state;
  std::tie(output, hidden_state, cell_state) = torch::lstm(
      input,
      {state[0], state[1]},
      flat_weights_,
      options.with_bias(),
      options.layers(),
      options.dropout(),
      this->is_training(),
      options.bidirectional(),
      options.batch_first());
  return {output, torch::stack({hidden_state, cell_state})};
}
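// Sketch of consuming the stacked state (our example): since (hx, cx) are
// packed along dim 0, callers can recover them as
//   auto out = lstm->forward(input);
//   auto hx = out.state[0];  // (layers * num_directions, batch, hidden)
//   auto cx = out.state[1];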
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
GRUImpl::GRUImpl(const GRUOptions& options_)
: detail::RNNImplBase<GRUImpl>(
          options_,
          CuDNNMode::GRU,
          /*number_of_gates=*/3) {}
std::tuple<Tensor, Tensor> GRUImpl::forward_helper(
const Tensor& input,
const Tensor& batch_sizes,
const Tensor& sorted_indices,
int64_t max_batch_size,
Tensor hx) {
if (!hx.defined()) {
int64_t num_directions = options.bidirectional() ? 2 : 1;
hx = torch::zeros({options.num_layers() * num_directions,
max_batch_size, options.hidden_size()},
torch::dtype(input.dtype()).device(input.device()));
} else {
    // Each batch of the hidden state should match the input sequence that
    // the user believes they are passing in.
hx = this->permute_hidden(hx, sorted_indices);
}
this->check_forward_args(input, hx, batch_sizes);
std::tuple<Tensor, Tensor> result;
if (!batch_sizes.defined()) {
result = torch::gru(input, hx, flat_weights_, options.bias(), options.num_layers(),
options.dropout(), this->is_training(), options.bidirectional(), options.batch_first());
} else {
result = torch::gru(input, batch_sizes, hx, flat_weights_, options.bias(),
options.num_layers(), options.dropout(), this->is_training(), options.bidirectional());
}
auto output = std::get<0>(result);
auto hidden = std::get<1>(result);
return std::make_tuple(output, hidden);
}
std::tuple<Tensor, Tensor> GRUImpl::forward(const Tensor& input, Tensor hx) {
auto batch_sizes = torch::Tensor();
auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
auto sorted_indices = torch::Tensor();
auto unsorted_indices = torch::Tensor();
Tensor output, hidden;
std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
}
std::tuple<PackedSequence, Tensor> GRUImpl::forward_with_packed_input(const PackedSequence& packed_input, Tensor hx) {
const auto& input = packed_input.data();
const auto& batch_sizes = packed_input.batch_sizes();
const auto& sorted_indices = packed_input.sorted_indices();
const auto& unsorted_indices = packed_input.unsorted_indices();
auto max_batch_size = batch_sizes[0].item<int64_t>();
Tensor output, hidden;
std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
  return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
}
RNNOutput GRUImpl::forward(const Tensor& input, Tensor state) {
return generic_forward(
static_cast<RNNFunctionSignature*>(&torch::gru), input, std::move(state));
}
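// Illustrative usage (our sketch, not part of this file): a default-state
// call looks like
//   GRU gru(GRUOptions(32, 64).layers(2));
//   auto out = gru->forward(torch::randn({5, 3, 32}));  // state zero-initialized
//   // out.output is (5, 3, 64); out.state is (2, 3, 64)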
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -5,19 +5,21 @@ namespace nn {
namespace detail {
RNNOptionsBase::RNNOptionsBase(rnn_options_base_mode_t mode, int64_t input_size, int64_t hidden_size)
: mode_(mode), input_size_(input_size), hidden_size_(hidden_size) {}
RNNOptionsBase::RNNOptionsBase(int64_t input_size, int64_t hidden_size)
: input_size_(input_size), hidden_size_(hidden_size) {}
} // namespace detail
RNNOptions::RNNOptions(int64_t input_size, int64_t hidden_size)
: input_size_(input_size), hidden_size_(hidden_size) {}
LSTMOptions::LSTMOptions(int64_t input_size, int64_t hidden_size)
: input_size_(input_size), hidden_size_(hidden_size) {}
RNNOptions& RNNOptions::tanh() {
return activation(RNNActivation::Tanh);
}
GRUOptions::GRUOptions(int64_t input_size, int64_t hidden_size)
: input_size_(input_size), hidden_size_(hidden_size) {}
RNNOptions& RNNOptions::relu() {
return activation(RNNActivation::ReLU);
}
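// e.g. (equivalent spellings): RNNOptions(16, 32).tanh() is shorthand for
// RNNOptions(16, 32).activation(RNNActivation::Tanh).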
namespace detail {