Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[C++ API] RNN / GRU / LSTM layer refactoring (#34322)
Summary: This PR refactors the RNN / GRU / LSTM layers in the C++ API to exactly match the implementation in the Python API.

**BC-breaking changes:**
- Instead of returning `RNNOutput`, the RNN / GRU forward method now returns `std::tuple<Tensor, Tensor>`, and the LSTM forward method now returns `std::tuple<Tensor, std::tuple<Tensor, Tensor>>`, matching the Python API.
- The RNN / LSTM / GRU forward method now accepts the same inputs (input tensor and optionally hidden state), matching the Python API.
- RNN / LSTM / GRU layers now have a `forward_with_packed_input` method which accepts a `PackedSequence` as input and optionally a hidden state, matching the `forward(PackedSequence, ...)` variant in the Python API.
- RNN / LSTM / GRU layers no longer have these fields: `w_ih` / `w_hh` / `b_ih` / `b_hh`. Instead, to access the weights and biases of the gates, users should write e.g. `rnn->named_parameters()["weight_ih_l0"]`, which mirrors the Python API `rnn.weight_ih_l0`.
- In `RNNOptions`:
  - `tanh()` / `relu()` / `activation` are removed. Instead, `nonlinearity` is added, which takes either `torch::kTanh` or `torch::kReLU`
  - `layers` -> `num_layers`
  - `with_bias` -> `bias`
- In `LSTMOptions`:
  - `layers` -> `num_layers`
  - `with_bias` -> `bias`
- In `GRUOptions`:
  - `layers` -> `num_layers`
  - `with_bias` -> `bias`

The majority of the changes in this PR focus on refactoring the implementations in `torch/csrc/api/src/nn/modules/rnn.cpp` to match the Python API. The RNN tests are then changed to reflect the revised API design.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/34322

Differential Revision: D20458302

Pulled By: yf225

fbshipit-source-id: ffff2ae1ddb1c742c966956f6ad4d7fba03dc54d
committed by: Facebook GitHub Bot
parent: d4f182d06b
commit: bdd7dbfd4b
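As a quick orientation for the changes below, here is a minimal sketch of call sites migrated to the new API described in the summary (the module sizes and tensor shapes are illustrative, not taken from this PR):

```cpp
#include <torch/torch.h>

int main() {
  // New API: LSTM::forward returns std::tuple<Tensor, std::tuple<Tensor, Tensor>>.
  torch::nn::LSTM lstm(torch::nn::LSTMOptions(/*input_size=*/128, /*hidden_size=*/64).num_layers(3));
  auto x = torch::randn({10, 16, 128});  // (sequence, batch, features)

  torch::Tensor output;
  std::tuple<torch::Tensor, torch::Tensor> state;  // (h_n, c_n)
  std::tie(output, state) = lstm->forward(x);

  // The returned state can be fed into the next forward step.
  auto next = lstm->forward(x, state);

  // Weights are no longer exposed as w_ih / w_hh fields; use named_parameters(),
  // mirroring Python's rnn.weight_ih_l0.
  auto w_ih_l0 = lstm->named_parameters()["weight_ih_l0"];

  // RNN now takes the nonlinearity as an enum instead of .tanh() / .relu().
  torch::nn::RNN rnn(torch::nn::RNNOptions(128, 64).nonlinearity(torch::kReLU).num_layers(2));
  auto rnn_out = rnn->forward(x);            // std::tuple<Tensor, Tensor>
  auto last_step = std::get<0>(rnn_out)[9];  // output at the final time step
  return 0;
}
```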
@@ -43,7 +43,11 @@ TEST(EnumTest, AllEnums) {
     torch::enumtype::kBatchMean,
     torch::enumtype::kZeros,
     torch::enumtype::kBorder,
-    torch::enumtype::kReflection
+    torch::enumtype::kReflection,
+    torch::enumtype::kRNN_TANH,
+    torch::enumtype::kRNN_RELU,
+    torch::enumtype::kLSTM,
+    torch::enumtype::kGRU
   > v;
 
   TORCH_ENUM_PRETTY_PRINT_TEST(Linear)
@@ -76,4 +80,8 @@ TEST(EnumTest, AllEnums) {
   TORCH_ENUM_PRETTY_PRINT_TEST(Zeros)
   TORCH_ENUM_PRETTY_PRINT_TEST(Border)
   TORCH_ENUM_PRETTY_PRINT_TEST(Reflection)
+  TORCH_ENUM_PRETTY_PRINT_TEST(RNN_TANH)
+  TORCH_ENUM_PRETTY_PRINT_TEST(RNN_RELU)
+  TORCH_ENUM_PRETTY_PRINT_TEST(LSTM)
+  TORCH_ENUM_PRETTY_PRINT_TEST(GRU)
 }
@@ -283,7 +283,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
       " (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
       " (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
-      " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+      " (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
       ")");
 }
@@ -23,7 +23,7 @@ bool test_RNN_xor(Func&& model_maker, bool cuda = false) {
     auto B = x.size(1);
     x = x.view({T * B, 1});
     x = l1->forward(x).view({T, B, nhid}).tanh_();
-    x = rnn->forward(x).output[T - 1];
+    x = std::get<0>(rnn->forward(x))[T - 1];
     x = lo->forward(x);
     return x;
   };
@@ -61,29 +61,39 @@ bool test_RNN_xor(Func&& model_maker, bool cuda = false) {
   return true;
 };
 
-void check_lstm_sizes(RNNOutput output) {
+void check_lstm_sizes(std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output) {
   // Expect the LSTM to have 64 outputs and 3 layers, with an input of batch
   // 10 and 16 time steps (10 x 16 x n)
 
-  ASSERT_EQ(output.output.ndimension(), 3);
-  ASSERT_EQ(output.output.size(0), 10);
-  ASSERT_EQ(output.output.size(1), 16);
-  ASSERT_EQ(output.output.size(2), 64);
+  torch::Tensor output = std::get<0>(lstm_output);
+  std::tuple<torch::Tensor, torch::Tensor> state = std::get<1>(lstm_output);
+  torch::Tensor hx = std::get<0>(state);
+  torch::Tensor cx = std::get<1>(state);
 
-  ASSERT_EQ(output.state.ndimension(), 4);
-  ASSERT_EQ(output.state.size(0), 2); // (hx, cx)
-  ASSERT_EQ(output.state.size(1), 3); // layers
-  ASSERT_EQ(output.state.size(2), 16); // Batchsize
-  ASSERT_EQ(output.state.size(3), 64); // 64 hidden dims
+  ASSERT_EQ(output.ndimension(), 3);
+  ASSERT_EQ(output.size(0), 10);
+  ASSERT_EQ(output.size(1), 16);
+  ASSERT_EQ(output.size(2), 64);
+
+  ASSERT_EQ(hx.ndimension(), 3);
+  ASSERT_EQ(hx.size(0), 3); // layers
+  ASSERT_EQ(hx.size(1), 16); // Batchsize
+  ASSERT_EQ(hx.size(2), 64); // 64 hidden dims
+
+  ASSERT_EQ(cx.ndimension(), 3);
+  ASSERT_EQ(cx.size(0), 3); // layers
+  ASSERT_EQ(cx.size(1), 16); // Batchsize
+  ASSERT_EQ(cx.size(2), 64); // 64 hidden dims
 
   // Something is in the hiddens
-  ASSERT_GT(output.state.norm().item<float>(), 0);
+  ASSERT_GT(hx.norm().item<float>(), 0);
+  ASSERT_GT(cx.norm().item<float>(), 0);
 }
 
 struct RNNTest : torch::test::SeedingFixture {};
 
 TEST_F(RNNTest, CheckOutputSizes) {
-  LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
+  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
   // Input size is: sequence length, batch size, input size
   auto x = torch::randn({10, 16, 128}, torch::requires_grad());
   auto output = model->forward(x);
@@ -92,11 +102,17 @@ TEST_F(RNNTest, CheckOutputSizes) {
   y.backward();
   check_lstm_sizes(output);
 
-  auto next = model->forward(x, output.state);
+  auto next = model->forward(x, std::get<1>(output));
 
   check_lstm_sizes(next);
 
-  torch::Tensor diff = next.state - output.state;
+  auto output_hx = std::get<0>(std::get<1>(output));
+  auto output_cx = std::get<1>(std::get<1>(output));
+
+  auto next_hx = std::get<0>(std::get<1>(next));
+  auto next_cx = std::get<1>(std::get<1>(next));
+
+  torch::Tensor diff = torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0);
 
   // Hiddens changed
   ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
@@ -122,12 +138,12 @@ TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) {
   }
 
   auto out = model->forward(x);
-  ASSERT_EQ(out.output.ndimension(), 3);
-  ASSERT_EQ(out.output.size(0), 3);
-  ASSERT_EQ(out.output.size(1), 4);
-  ASSERT_EQ(out.output.size(2), 2);
+  ASSERT_EQ(std::get<0>(out).ndimension(), 3);
+  ASSERT_EQ(std::get<0>(out).size(0), 3);
+  ASSERT_EQ(std::get<0>(out).size(1), 4);
+  ASSERT_EQ(std::get<0>(out).size(2), 2);
 
-  auto flat = out.output.view(3 * 4 * 2);
+  auto flat = std::get<0>(out).view(3 * 4 * 2);
   float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
                    0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
                    0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
@@ -136,12 +152,20 @@ TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) {
     ASSERT_LT(std::abs(flat[i].item<float>() - c_out[i]), 1e-3);
   }
 
-  ASSERT_EQ(out.state.ndimension(), 4); // (hx, cx) x layers x B x 2
-  ASSERT_EQ(out.state.size(0), 2);
-  ASSERT_EQ(out.state.size(1), 1);
-  ASSERT_EQ(out.state.size(2), 4);
-  ASSERT_EQ(out.state.size(3), 2);
-  flat = out.state.view(16);
+  auto hx = std::get<0>(std::get<1>(out));
+  auto cx = std::get<1>(std::get<1>(out));
+
+  ASSERT_EQ(hx.ndimension(), 3); // layers x B x 2
+  ASSERT_EQ(hx.size(0), 1);
+  ASSERT_EQ(hx.size(1), 4);
+  ASSERT_EQ(hx.size(2), 2);
+
+  ASSERT_EQ(cx.ndimension(), 3); // layers x B x 2
+  ASSERT_EQ(cx.size(0), 1);
+  ASSERT_EQ(cx.size(1), 4);
+  ASSERT_EQ(cx.size(2), 2);
+
+  flat = torch::cat({hx, cx}, 0).view(16);
   float h_out[] = {0.7889,
                    0.9003,
                    0.7769,
@@ -165,27 +189,27 @@ TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) {
 
 TEST_F(RNNTest, EndToEndLSTM) {
   ASSERT_TRUE(test_RNN_xor<LSTM>(
-      [](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }));
+      [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }));
 }
 
 TEST_F(RNNTest, EndToEndGRU) {
   ASSERT_TRUE(
-      test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).layers(2)); }));
+      test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }));
 }
 
 TEST_F(RNNTest, EndToEndRNNRelu) {
   ASSERT_TRUE(test_RNN_xor<RNN>(
-      [](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }));
+      [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); }));
 }
 
 TEST_F(RNNTest, EndToEndRNNTanh) {
   ASSERT_TRUE(test_RNN_xor<RNN>(
-      [](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }));
+      [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); }));
 }
 
 TEST_F(RNNTest, Sizes_CUDA) {
   torch::manual_seed(0);
-  LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
+  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
   model->to(torch::kCUDA);
   auto x =
       torch::randn({10, 16, 128}, torch::requires_grad().device(torch::kCUDA));
@@ -195,11 +219,17 @@ TEST_F(RNNTest, Sizes_CUDA) {
   y.backward();
   check_lstm_sizes(output);
 
-  auto next = model->forward(x, output.state);
+  auto next = model->forward(x, std::get<1>(output));
 
   check_lstm_sizes(next);
 
-  torch::Tensor diff = next.state - output.state;
+  auto output_hx = std::get<0>(std::get<1>(output));
+  auto output_cx = std::get<1>(std::get<1>(output));
+
+  auto next_hx = std::get<0>(std::get<1>(next));
+  auto next_cx = std::get<1>(std::get<1>(next));
+
+  torch::Tensor diff = torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0);
 
   // Hiddens changed
   ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
@@ -207,51 +237,68 @@ TEST_F(RNNTest, Sizes_CUDA) {
 
 TEST_F(RNNTest, EndToEndLSTM_CUDA) {
   ASSERT_TRUE(test_RNN_xor<LSTM>(
-      [](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }, true));
+      [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }, true));
 }
 
 TEST_F(RNNTest, EndToEndGRU_CUDA) {
   ASSERT_TRUE(test_RNN_xor<GRU>(
-      [](int s) { return GRU(GRUOptions(s, s).layers(2)); }, true));
+      [](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }, true));
 }
 
 TEST_F(RNNTest, EndToEndRNNRelu_CUDA) {
   ASSERT_TRUE(test_RNN_xor<RNN>(
-      [](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }, true));
+      [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); }, true));
 }
 TEST_F(RNNTest, EndToEndRNNTanh_CUDA) {
   ASSERT_TRUE(test_RNN_xor<RNN>(
-      [](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }, true));
+      [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); }, true));
 }
 
 TEST_F(RNNTest, PrettyPrintRNNs) {
   ASSERT_EQ(
-      c10::str(LSTM(LSTMOptions(128, 64).layers(3).dropout(0.2))),
-      "torch::nn::LSTM(input_size=128, hidden_size=64, layers=3, dropout=0.2)");
+      c10::str(LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2))),
+      "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");
   ASSERT_EQ(
-      c10::str(GRU(GRUOptions(128, 64).layers(3).dropout(0.5))),
-      "torch::nn::GRU(input_size=128, hidden_size=64, layers=3, dropout=0.5)");
+      c10::str(GRU(GRUOptions(128, 64).num_layers(3).dropout(0.5))),
+      "torch::nn::GRU(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.5, bidirectional=false)");
   ASSERT_EQ(
-      c10::str(RNN(RNNOptions(128, 64).layers(3).dropout(0.2).tanh())),
-      "torch::nn::RNN(input_size=128, hidden_size=64, layers=3, dropout=0.2, activation=tanh)");
+      c10::str(RNN(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh))),
+      "torch::nn::RNN(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");
 }
 
 // This test assures that flatten_parameters does not crash,
 // when bidirectional is set to true
 // https://github.com/pytorch/pytorch/issues/19545
 TEST_F(RNNTest, BidirectionalFlattenParameters) {
-  GRU gru(GRUOptions(100, 256).layers(2).bidirectional(true));
+  GRU gru(GRUOptions(100, 256).num_layers(2).bidirectional(true));
   gru->flatten_parameters();
 }
 
 template <typename Impl>
-void copyParameters(torch::nn::ModuleHolder<Impl>& target, size_t t_i,
-                    const torch::nn::ModuleHolder<Impl>& source, size_t s_i) {
+void copyParameters(torch::nn::ModuleHolder<Impl>& target, std::string t_suffix,
+                    const torch::nn::ModuleHolder<Impl>& source, std::string s_suffix) {
   at::NoGradGuard guard;
-  target->w_ih[t_i].copy_(source->w_ih[s_i]);
-  target->w_hh[t_i].copy_(source->w_hh[s_i]);
-  target->b_ih[t_i].copy_(source->b_ih[s_i]);
-  target->b_hh[t_i].copy_(source->b_hh[s_i]);
+  target->named_parameters()["weight_ih_l" + t_suffix].copy_(source->named_parameters()["weight_ih_l" + s_suffix]);
+  target->named_parameters()["weight_hh_l" + t_suffix].copy_(source->named_parameters()["weight_hh_l" + s_suffix]);
+  target->named_parameters()["bias_ih_l" + t_suffix].copy_(source->named_parameters()["bias_ih_l" + s_suffix]);
+  target->named_parameters()["bias_hh_l" + t_suffix].copy_(source->named_parameters()["bias_hh_l" + s_suffix]);
 }
 
+std::tuple<torch::Tensor, torch::Tensor> gru_output_to_device(
+    std::tuple<torch::Tensor, torch::Tensor> gru_output, torch::Device device) {
+  return std::make_tuple(
+      std::get<0>(gru_output).to(device),
+      std::get<1>(gru_output).to(device));
+}
+
+std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output_to_device(
+    std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output, torch::Device device) {
+  auto hidden_states = std::get<1>(lstm_output);
+  return std::make_tuple(
+      std::get<0>(lstm_output).to(device),
+      std::make_tuple(
+          std::get<0>(hidden_states).to(device),
+          std::get<1>(hidden_states).to(device)));
+}
 
 // This test is a port of python code introduced here:
@@ -264,7 +311,7 @@ void BidirectionalGRUReverseForward(bool cuda) {
   auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
   auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});
 
-  auto gru_options = GRUOptions(1, 1).layers(1).batch_first(false);
+  auto gru_options = GRUOptions(1, 1).num_layers(1).batch_first(false);
   GRU bi_grus {gru_options.bidirectional(true)};
   GRU reverse_gru {gru_options.bidirectional(false)};
@@ -275,28 +322,26 @@ void BidirectionalGRUReverseForward(bool cuda) {
 
   // Now make sure the weights of the reverse gru layer match
   // ones of the (reversed) bidirectional's:
-  copyParameters(reverse_gru, 0, bi_grus, 1);
+  copyParameters(reverse_gru, "0", bi_grus, "0_reverse");
 
   auto bi_output = bi_grus->forward(input);
   auto reverse_output = reverse_gru->forward(input_reversed);
 
   if (cuda) {
-    bi_output.output = bi_output.output.to(torch::kCPU);
-    bi_output.state = bi_output.state.to(torch::kCPU);
-    reverse_output.output = reverse_output.output.to(torch::kCPU);
-    reverse_output.state = reverse_output.state.to(torch::kCPU);
+    bi_output = gru_output_to_device(bi_output, torch::kCPU);
+    reverse_output = gru_output_to_device(reverse_output, torch::kCPU);
   }
 
-  ASSERT_EQ(bi_output.output.size(0), reverse_output.output.size(0));
-  auto size = bi_output.output.size(0);
+  ASSERT_EQ(std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
+  auto size = std::get<0>(bi_output).size(0);
   for (int i = 0; i < size; i++) {
-    ASSERT_EQ(bi_output.output[i][0][1].item<float>(),
-              reverse_output.output[size - 1 - i][0][0].item<float>());
+    ASSERT_EQ(std::get<0>(bi_output)[i][0][1].item<float>(),
+              std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
   }
   // The hidden states of the reversed GRUs sits
   // in the odd indices in the first dimension.
-  ASSERT_EQ(bi_output.state[1][0][0].item<float>(),
-            reverse_output.state[0][0][0].item<float>());
+  ASSERT_EQ(std::get<1>(bi_output)[1][0][0].item<float>(),
+            std::get<1>(reverse_output)[0][0][0].item<float>());
 }
 
 TEST_F(RNNTest, BidirectionalGRUReverseForward) {
@@ -315,7 +360,7 @@ void BidirectionalLSTMReverseForwardTest(bool cuda) {
   auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
   auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});
 
-  auto lstm_opt = GRUOptions(1, 1).layers(1).batch_first(false);
+  auto lstm_opt = LSTMOptions(1, 1).num_layers(1).batch_first(false);
 
   LSTM bi_lstm {lstm_opt.bidirectional(true)};
   LSTM reverse_lstm {lstm_opt.bidirectional(false)};
@@ -327,30 +372,28 @@ void BidirectionalLSTMReverseForwardTest(bool cuda) {
 
   // Now make sure the weights of the reverse lstm layer match
   // ones of the (reversed) bidirectional's:
-  copyParameters(reverse_lstm, 0, bi_lstm, 1);
+  copyParameters(reverse_lstm, "0", bi_lstm, "0_reverse");
 
   auto bi_output = bi_lstm->forward(input);
   auto reverse_output = reverse_lstm->forward(input_reversed);
 
   if (cuda) {
-    bi_output.output = bi_output.output.to(torch::kCPU);
-    bi_output.state = bi_output.state.to(torch::kCPU);
-    reverse_output.output = reverse_output.output.to(torch::kCPU);
-    reverse_output.state = reverse_output.state.to(torch::kCPU);
+    bi_output = lstm_output_to_device(bi_output, torch::kCPU);
+    reverse_output = lstm_output_to_device(reverse_output, torch::kCPU);
   }
 
-  ASSERT_EQ(bi_output.output.size(0), reverse_output.output.size(0));
-  auto size = bi_output.output.size(0);
+  ASSERT_EQ(std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
+  auto size = std::get<0>(bi_output).size(0);
   for (int i = 0; i < size; i++) {
-    ASSERT_EQ(bi_output.output[i][0][1].item<float>(),
-              reverse_output.output[size - 1 - i][0][0].item<float>());
+    ASSERT_EQ(std::get<0>(bi_output)[i][0][1].item<float>(),
+              std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
   }
   // The hidden states of the reversed LSTM sits
   // in the odd indices in the first dimension.
-  ASSERT_EQ(bi_output.state[0][1][0][0].item<float>(),
-            reverse_output.state[0][0][0][0].item<float>());
-  ASSERT_EQ(bi_output.state[1][1][0][0].item<float>(),
-            reverse_output.state[1][0][0][0].item<float>());
+  ASSERT_EQ(std::get<0>(std::get<1>(bi_output))[1][0][0].item<float>(),
+            std::get<0>(std::get<1>(reverse_output))[0][0][0].item<float>());
+  ASSERT_EQ(std::get<1>(std::get<1>(bi_output))[1][0][0].item<float>(),
+            std::get<1>(std::get<1>(reverse_output))[0][0][0].item<float>());
 }
 
 TEST_F(RNNTest, BidirectionalLSTMReverseForward) {
@@ -363,19 +406,15 @@ TEST_F(RNNTest, BidirectionalLSTMReverseForward_CUDA) {
 
 TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
   // Create two GRUs with the same options
-  auto opt = GRUOptions(2, 4).layers(3).batch_first(false).bidirectional(true);
+  auto opt = GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
   GRU gru_cpu {opt};
   GRU gru_cuda {opt};
 
   // Copy weights and biases from CPU GRU to CUDA GRU
   {
     at::NoGradGuard guard;
-    const auto num_directions = gru_cpu->options.bidirectional() ? 2 : 1;
-    for (int64_t layer = 0; layer < gru_cpu->options.layers(); layer++) {
-      for (auto direction = 0; direction < num_directions; direction++) {
-        const auto layer_idx = (layer * num_directions) + direction;
-        copyParameters(gru_cuda, layer_idx, gru_cpu, layer_idx);
-      }
+    for (const auto& param : gru_cpu->named_parameters(/*recurse=*/false)) {
+      gru_cuda->named_parameters()[param.key()].copy_(gru_cpu->named_parameters()[param.key()]);
     }
   }
@@ -397,20 +436,19 @@ TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
   auto output_cpu = gru_cpu->forward(input_cpu);
   auto output_cuda = gru_cuda->forward(input_cuda);
 
-  output_cpu.output = output_cpu.output.to(torch::kCPU);
-  output_cpu.state = output_cpu.state.to(torch::kCPU);
+  output_cpu = gru_output_to_device(output_cpu, torch::kCPU);
 
   // Assert that the output and state are equal on CPU and CUDA
-  ASSERT_EQ(output_cpu.output.dim(), output_cuda.output.dim());
-  for (int i = 0; i < output_cpu.output.dim(); i++) {
-    ASSERT_EQ(output_cpu.output.size(i), output_cuda.output.size(i));
+  ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
+  for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
+    ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
   }
-  for (int i = 0; i < output_cpu.output.size(0); i++) {
-    for (int j = 0; j < output_cpu.output.size(1); j++) {
-      for (int k = 0; k < output_cpu.output.size(2); k++) {
+  for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
+    for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
+      for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
         ASSERT_NEAR(
-            output_cpu.output[i][j][k].item<float>(),
-            output_cuda.output[i][j][k].item<float>(), 1e-5);
+            std::get<0>(output_cpu)[i][j][k].item<float>(),
+            std::get<0>(output_cuda)[i][j][k].item<float>(), 1e-5);
       }
     }
   }
@@ -418,19 +456,15 @@ TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
 
 TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
   // Create two LSTMs with the same options
-  auto opt = LSTMOptions(2, 4).layers(3).batch_first(false).bidirectional(true);
+  auto opt = LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
   LSTM lstm_cpu {opt};
   LSTM lstm_cuda {opt};
 
   // Copy weights and biases from CPU LSTM to CUDA LSTM
   {
     at::NoGradGuard guard;
-    const auto num_directions = lstm_cpu->options.bidirectional() ? 2 : 1;
-    for (int64_t layer = 0; layer < lstm_cpu->options.layers(); layer++) {
-      for (auto direction = 0; direction < num_directions; direction++) {
-        const auto layer_idx = (layer * num_directions) + direction;
-        copyParameters(lstm_cuda, layer_idx, lstm_cpu, layer_idx);
-      }
+    for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) {
+      lstm_cuda->named_parameters()[param.key()].copy_(lstm_cpu->named_parameters()[param.key()]);
     }
   }
@@ -451,21 +485,68 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
   auto output_cpu = lstm_cpu->forward(input_cpu);
   auto output_cuda = lstm_cuda->forward(input_cuda);
 
-  output_cpu.output = output_cpu.output.to(torch::kCPU);
-  output_cpu.state = output_cpu.state.to(torch::kCPU);
+  output_cpu = lstm_output_to_device(output_cpu, torch::kCPU);
 
   // Assert that the output and state are equal on CPU and CUDA
-  ASSERT_EQ(output_cpu.output.dim(), output_cuda.output.dim());
-  for (int i = 0; i < output_cpu.output.dim(); i++) {
-    ASSERT_EQ(output_cpu.output.size(i), output_cuda.output.size(i));
+  ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
+  for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
+    ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
   }
-  for (int i = 0; i < output_cpu.output.size(0); i++) {
-    for (int j = 0; j < output_cpu.output.size(1); j++) {
-      for (int k = 0; k < output_cpu.output.size(2); k++) {
+  for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
+    for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
+      for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
         ASSERT_NEAR(
-            output_cpu.output[i][j][k].item<float>(),
-            output_cuda.output[i][j][k].item<float>(), 1e-5);
+            std::get<0>(output_cpu)[i][j][k].item<float>(),
+            std::get<0>(output_cuda)[i][j][k].item<float>(), 1e-5);
       }
     }
   }
 }
 
+TEST_F(RNNTest, UsePackedSequenceAsInput) {
+  {
+    torch::manual_seed(0);
+    auto m = RNN(2, 3);
+    torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
+    auto rnn_output = m->forward_with_packed_input(packed_input);
+    auto expected_output = torch::tensor(
+        {{-0.0645, -0.7274, 0.4531},
+         {-0.3970, -0.6950, 0.6009},
+         {-0.3877, -0.7310, 0.6806}});
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+
+    // Test passing optional argument to `RNN::forward_with_packed_input`
+    rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+  }
+  {
+    torch::manual_seed(0);
+    auto m = LSTM(2, 3);
+    torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
+    auto rnn_output = m->forward_with_packed_input(packed_input);
+    auto expected_output = torch::tensor(
+        {{-0.2693, -0.1240, 0.0744},
+         {-0.3889, -0.1919, 0.1183},
+         {-0.4425, -0.2314, 0.1386}});
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+
+    // Test passing optional argument to `LSTM::forward_with_packed_input`
+    rnn_output = m->forward_with_packed_input(packed_input, torch::nullopt);
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+  }
+  {
+    torch::manual_seed(0);
+    auto m = GRU(2, 3);
+    torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
+    auto rnn_output = m->forward_with_packed_input(packed_input);
+    auto expected_output = torch::tensor(
+        {{-0.1134, 0.0467, 0.2336},
+         {-0.1189, 0.0502, 0.2960},
+         {-0.1138, 0.0484, 0.3110}});
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+
+    // Test passing optional argument to `GRU::forward_with_packed_input`
+    rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+  }
+}
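The tests above exercise the new packed-sequence path with a single sequence; a minimal sketch of driving `forward_with_packed_input` with variable-length batches (the lengths and sizes are illustrative, and `pack_sequence` is assumed to use its default `enforce_sorted` behavior, so sequences are passed in decreasing length order):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  using torch::nn::utils::rnn::pack_sequence;
  using torch::nn::utils::rnn::PackedSequence;

  torch::nn::LSTM lstm(torch::nn::LSTMOptions(2, 3));

  // Two sequences of different lengths, packed into one batch.
  PackedSequence packed = pack_sequence({torch::ones({3, 2}), torch::ones({2, 2})});

  // Returns std::tuple<PackedSequence, std::tuple<Tensor, Tensor>>.
  auto result = lstm->forward_with_packed_input(packed);
  PackedSequence packed_output = std::get<0>(result);
  auto state = std::get<1>(result);  // (h_n, c_n)

  // The flattened packed data is accessible via .data(), as in the tests above.
  std::cout << packed_output.data().sizes() << std::endl;  // [5, 3]
  return 0;
}
```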
@@ -410,7 +410,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       " (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
       " (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
-      " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+      " (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
       ")");
 
   Sequential sequential_named({
@@ -429,7 +429,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       " (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
       " (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
-      " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+      " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
       ")");
 }
 
@@ -598,21 +598,21 @@ TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
     torch::manual_seed(0);
     Sequential sequential(Identity(), RNN(2, 3));
     auto x = torch::ones({2, 3, 2});
-    auto rnn_output = sequential->forward<RNNOutput>(x);
+    auto rnn_output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
     auto expected_output = torch::tensor(
-        {{{0.0000, 0.0000, 0.4886},
-          {0.0000, 0.0000, 0.4886},
-          {0.0000, 0.0000, 0.4886}},
-         {{0.0000, 0.0000, 0.3723},
-          {0.0000, 0.0000, 0.3723},
-          {0.0000, 0.0000, 0.3723}}});
-    ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
+        {{{-0.0645, -0.7274, 0.4531},
+          {-0.0645, -0.7274, 0.4531},
+          {-0.0645, -0.7274, 0.4531}},
+         {{-0.3970, -0.6950, 0.6009},
+          {-0.3970, -0.6950, 0.6009},
+          {-0.3970, -0.6950, 0.6009}}});
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
   }
   {
     torch::manual_seed(0);
     Sequential sequential(Identity(), LSTM(2, 3));
     auto x = torch::ones({2, 3, 2});
-    auto rnn_output = sequential->forward<RNNOutput>(x);
+    auto rnn_output = sequential->forward<std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>>(x);
     auto expected_output = torch::tensor(
         {{{-0.2693, -0.1240, 0.0744},
          {-0.2693, -0.1240, 0.0744},
@@ -620,13 +620,13 @@ TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
          {{-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183}}});
-    ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
   }
   {
     torch::manual_seed(0);
     Sequential sequential(Identity(), GRU(2, 3));
     auto x = torch::ones({2, 3, 2});
-    auto rnn_output = sequential->forward<RNNOutput>(x);
+    auto rnn_output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
     auto expected_output = torch::tensor(
         {{{-0.1134, 0.0467, 0.2336},
          {-0.1134, 0.0467, 0.2336},
@@ -634,7 +634,7 @@ TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
          {{-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960}}});
-    ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
   }
   {
     torch::manual_seed(0);
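Because `Sequential::forward` is templated on the return type, call sites that previously requested `RNNOutput` now request the tuple type instead, as the hunks above show. A minimal sketch (module sizes illustrative):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  using namespace torch::nn;

  Sequential seq(Identity(), LSTM(2, 3));
  auto x = torch::ones({2, 3, 2});

  // The template argument names LSTM::forward's new tuple return type.
  auto out = seq->forward<std::tuple<torch::Tensor,
                                     std::tuple<torch::Tensor, torch::Tensor>>>(x);

  torch::Tensor output = std::get<0>(out);  // (sequence, batch, hidden)
  torch::Tensor h_n = std::get<0>(std::get<1>(out));
  torch::Tensor c_n = std::get<1>(std::get<1>(out));
  std::cout << output.sizes() << std::endl;
  return 0;
}
```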
@@ -81,9 +81,9 @@ torch::nn::InstanceNorm3d|Yes|No
 torch::nn::LayerNorm|Yes|No
 torch::nn::LocalResponseNorm|Yes|No
 torch::nn::CrossMapLRN2d|Yes|No
-torch::nn::RNN|No|No
-torch::nn::LSTM|No|No
-torch::nn::GRU|No|No
+torch::nn::RNN|Yes|No
+torch::nn::LSTM|Yes|No
+torch::nn::GRU|Yes|No
 torch::nn::RNNCell|Yes|No
 torch::nn::LSTMCell|Yes|No
 torch::nn::GRUCell|Yes|No
@@ -64,7 +64,7 @@ std::string operator()(const enumtype::k##name& v) const { \
 // However, it throws the following error instead:
 //
 // ```
-// error: could not convert ‘torch::kNone’ from ‘const torch::enumtype::kNone’ to ‘torch::nn::SomeOptions’
+// error: could not convert `torch::kNone` from `const torch::enumtype::kNone` to `torch::nn::SomeOptions`
 // ```
 //
 // To get around this problem, we explicitly provide the following constructors for `SomeOptions`:
@@ -122,6 +122,10 @@ TORCH_ENUM_DECLARE(BatchMean)
 TORCH_ENUM_DECLARE(Zeros)
 TORCH_ENUM_DECLARE(Border)
 TORCH_ENUM_DECLARE(Reflection)
+TORCH_ENUM_DECLARE(RNN_TANH)
+TORCH_ENUM_DECLARE(RNN_RELU)
+TORCH_ENUM_DECLARE(LSTM)
+TORCH_ENUM_DECLARE(GRU)
 
 namespace torch {
 namespace enumtype {
@@ -157,6 +161,10 @@ struct _compute_enum_name {
 TORCH_ENUM_PRETTY_PRINT(Zeros)
 TORCH_ENUM_PRETTY_PRINT(Border)
 TORCH_ENUM_PRETTY_PRINT(Reflection)
+TORCH_ENUM_PRETTY_PRINT(RNN_TANH)
+TORCH_ENUM_PRETTY_PRINT(RNN_RELU)
+TORCH_ENUM_PRETTY_PRINT(LSTM)
+TORCH_ENUM_PRETTY_PRINT(GRU)
 };
 
 template <typename V>
@@ -5,6 +5,7 @@
 #include <torch/nn/modules/common.h>
 #include <torch/nn/modules/dropout.h>
 #include <torch/nn/pimpl.h>
+#include <torch/nn/utils/rnn.h>
 #include <torch/types.h>
 
 #include <ATen/ATen.h>
@@ -15,36 +16,23 @@
 #include <memory>
 #include <vector>
 
+using namespace torch::nn::utils::rnn;
+
 namespace torch {
 namespace nn {
 
-/// The output of a single invocation of an RNN module's `forward()` method.
-struct TORCH_API RNNOutput {
-  /// The result of applying the specific RNN algorithm
-  /// to the input tensor and input state.
-  Tensor output;
-  /// The new, updated state that can be fed into the RNN
-  /// in the next forward step.
-  Tensor state;
-};
-
 namespace detail {
 /// Base class for all RNN implementations (intended for code sharing).
 template <typename Derived>
 class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
  public:
-  /// These must line up with the CUDNN mode codes:
-  /// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
-  enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
-
-  explicit RNNImplBase(
-      const RNNOptionsBase& options_,
-      optional<CuDNNMode> cudnn_mode = nullopt,
-      int64_t number_of_gates = 1);
+  explicit RNNImplBase(const RNNOptionsBase& options_);
 
   /// Initializes the parameters of the RNN module.
   void reset() override;
 
+  void reset_parameters();
+
   /// Overrides `nn::Module::to()` to call `flatten_parameters()` after the
   /// original operation.
   void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
@@ -65,52 +53,32 @@ class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
   /// called once upon construction, inside `reset()`.
   void flatten_parameters();
 
-  /// The RNN's options.
-  RNNOptionsBase options;
+  std::vector<Tensor> all_weights() const;
 
-  /// The weights for `input x hidden` gates.
-  std::vector<Tensor> w_ih;
-  /// The weights for `hidden x hidden` gates.
-  std::vector<Tensor> w_hh;
-  /// The biases for `input x hidden` gates.
-  std::vector<Tensor> b_ih;
-  /// The biases for `hidden x hidden` gates.
-  std::vector<Tensor> b_hh;
+  /// The RNN's options.
+  RNNOptionsBase options_base;
 
  protected:
-  /// The function signature of `rnn_relu`, `rnn_tanh` and `gru`.
-  using RNNFunctionSignature = std::tuple<Tensor, Tensor>(
-      /*input=*/const Tensor&,
-      /*state=*/const Tensor&,
-      /*params=*/TensorList,
-      /*has_biases=*/bool,
-      /*layers=*/int64_t,
-      /*dropout=*/double,
-      /*train=*/bool,
-      /*bidirectional=*/bool,
-      /*batch_first=*/bool);
+  // Resets flat_weights_
+  // Note: be v. careful before removing this, as 3rd party device types
+  // likely rely on this behavior to properly .to() modules like LSTM.
+  void reset_flat_weights();
 
-  /// A generic `forward()` used for RNN and GRU (but not LSTM!). Takes the ATen
-  /// RNN function as first argument.
-  RNNOutput generic_forward(
-      std::function<RNNFunctionSignature> function,
-      const Tensor& input,
-      Tensor state);
+  void check_input(const Tensor& input, const Tensor& batch_sizes) const;
 
-  /// Returns a flat vector of all weights, with layer weights following each
-  /// other sequentially in (w_ih, w_hh, b_ih, b_hh) order.
-  std::vector<Tensor> flat_weights() const;
+  std::tuple<int64_t, int64_t, int64_t> get_expected_hidden_size(const Tensor& input, const Tensor& batch_sizes) const;
 
-  /// Very simple check if any of the parameters (weights, biases) are the same.
-  bool any_parameters_alias() const;
+  void check_hidden_size(
+      const Tensor& hx,
+      std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
+      std::string msg = "Expected hidden size {1}, got {2}") const;
 
-  /// The number of gate weights/biases required by the RNN subclass.
-  int64_t number_of_gates_;
+  void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const;
 
-  /// The cuDNN RNN mode, if this RNN subclass has any.
-  optional<CuDNNMode> cudnn_mode_;
+  Tensor permute_hidden(Tensor hx, const Tensor& permutation) const;
 
-  /// The cached result of the latest `flat_weights()` call.
+  std::vector<std::string> flat_weights_names_;
+  std::vector<std::vector<std::string>> all_weights_;
   std::vector<Tensor> flat_weights_;
 };
 } // namespace detail
@@ -118,84 +86,139 @@ class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 /// A multi-layer Elman RNN module with Tanh or ReLU activation.
-/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN to learn about the
-/// exact behavior of this module.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::RNNOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// RNN model(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
+/// ```
 class TORCH_API RNNImpl : public detail::RNNImplBase<RNNImpl> {
  public:
   RNNImpl(int64_t input_size, int64_t hidden_size)
       : RNNImpl(RNNOptions(input_size, hidden_size)) {}
   explicit RNNImpl(const RNNOptions& options_);
 
   /// Pretty prints the `RNN` module into the given `stream`.
   void pretty_print(std::ostream& stream) const override;
 
-  /// Applies the `RNN` module to an input sequence and input state.
-  /// The `input` should follow a `(sequence, batch, features)` layout unless
-  /// `batch_first` is true, in which case the layout should be `(batch,
-  /// sequence, features)`.
-  RNNOutput forward(const Tensor& input, Tensor state = {});
+  std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});
 protected:
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
 
 public:
+  std::tuple<PackedSequence, Tensor> forward_with_packed_input(const PackedSequence& packed_input, Tensor hx = {});
+
   RNNOptions options;
 
+ protected:
+  std::tuple<Tensor, Tensor> forward_helper(
+      const Tensor& input,
+      const Tensor& batch_sizes,
+      const Tensor& sorted_indices,
+      int64_t max_batch_size,
+      Tensor hx);
 };
 
 /// A `ModuleHolder` subclass for `RNNImpl`.
-/// See the documentation for `RNNImpl` class to learn what methods it provides,
-/// or the documentation for `ModuleHolder` to learn about PyTorch's module
-/// storage semantics.
+/// See the documentation for `RNNImpl` class to learn what methods it
+/// provides, and examples of how to use `RNN` with `torch::nn::RNNOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
 TORCH_MODULE(RNN);
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 /// A multi-layer long-short-term-memory (LSTM) module.
-/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM to learn about the
-/// exact behavior of this module.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::LSTMOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// LSTM model(LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
 class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
  public:
   LSTMImpl(int64_t input_size, int64_t hidden_size)
       : LSTMImpl(LSTMOptions(input_size, hidden_size)) {}
   explicit LSTMImpl(const LSTMOptions& options_);
 
-  /// Applies the `LSTM` module to an input sequence and input state.
-  /// The `input` should follow a `(sequence, batch, features)` layout unless
-  /// `batch_first` is true, in which case the layout should be `(batch,
-  /// sequence, features)`.
-  RNNOutput forward(const Tensor& input, Tensor state = {});
+  std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(
+      const Tensor& input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
 protected:
-  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
+  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})
 
 public:
+  std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> forward_with_packed_input(
+      const PackedSequence& packed_input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
+
   LSTMOptions options;
 
+ protected:
+  void check_forward_args(const Tensor& input, std::tuple<Tensor, Tensor> hidden, const Tensor& batch_sizes) const;
+
+  std::tuple<Tensor, Tensor> permute_hidden(std::tuple<Tensor, Tensor> hx, const Tensor& permutation) const;
+
+  std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward_helper(
+      const Tensor& input,
+      const Tensor& batch_sizes,
+      const Tensor& sorted_indices,
+      int64_t max_batch_size,
+      torch::optional<std::tuple<Tensor, Tensor>> hx_opt);
 };
 
 /// A `ModuleHolder` subclass for `LSTMImpl`.
 /// See the documentation for `LSTMImpl` class to learn what methods it
-/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
+/// provides, and examples of how to use `LSTM` with `torch::nn::LSTMOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
 /// module storage semantics.
 TORCH_MODULE(LSTM);
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 /// A multi-layer gated recurrent unit (GRU) module.
-/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn about the
-/// exact behavior of this module.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::GRUOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// GRU model(GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
 class TORCH_API GRUImpl : public detail::RNNImplBase<GRUImpl> {
  public:
   GRUImpl(int64_t input_size, int64_t hidden_size)
       : GRUImpl(GRUOptions(input_size, hidden_size)) {}
   explicit GRUImpl(const GRUOptions& options_);
 
-  /// Applies the `GRU` module to an input sequence and input state.
-  /// The `input` should follow a `(sequence, batch, features)` layout unless
-  /// `batch_first` is true, in which case the layout should be `(batch,
-  /// sequence, features)`.
-  RNNOutput forward(const Tensor& input, Tensor state = {});
+  std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});
 protected:
-  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
+  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::Tensor())})
 
 public:
+  std::tuple<PackedSequence, Tensor> forward_with_packed_input(const PackedSequence& packed_input, Tensor hx = {});
+
   GRUOptions options;
 
+ protected:
+  std::tuple<Tensor, Tensor> forward_helper(
+      const Tensor& input,
+      const Tensor& batch_sizes,
+      const Tensor& sorted_indices,
+      int64_t max_batch_size,
+      Tensor hx);
 };
 
 /// A `ModuleHolder` subclass for `GRUImpl`.
-/// See the documentation for `GRUImpl` class to learn what methods it provides,
-/// or the documentation for `ModuleHolder` to learn about PyTorch's module
-/// storage semantics.
+/// See the documentation for `GRUImpl` class to learn what methods it
+/// provides, and examples of how to use `GRU` with `torch::nn::GRUOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
 TORCH_MODULE(GRU);
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
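The header examples above assume the default `(sequence, batch, features)` input layout; a small sketch contrasting it with `batch_first(true)` (shapes illustrative):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Default: input is (sequence, batch, features).
  torch::nn::GRU gru(torch::nn::GRUOptions(8, 16).num_layers(2));
  auto x_tbf = torch::randn({5, 4, 8});           // 5 time steps, batch of 4
  auto y = gru->forward(x_tbf);
  std::cout << std::get<0>(y).sizes() << "\n";    // [5, 4, 16]

  // batch_first(true): input is (batch, sequence, features).
  torch::nn::GRU gru_bf(torch::nn::GRUOptions(8, 16).num_layers(2).batch_first(true));
  auto x_btf = torch::randn({4, 5, 8});
  auto y_bf = gru_bf->forward(x_btf);
  std::cout << std::get<0>(y_bf).sizes() << "\n"; // [4, 5, 16]
  return 0;
}
```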
@@ -10,65 +10,137 @@ namespace nn {
 
 namespace detail {
 
-/// Common options for LSTM and GRU modules.
+/// Common options for RNN, LSTM and GRU modules.
 struct TORCH_API RNNOptionsBase {
-  RNNOptionsBase(int64_t input_size, int64_t hidden_size);
   virtual ~RNNOptionsBase() = default;
+
+  typedef c10::variant<
+      enumtype::kLSTM,
+      enumtype::kGRU,
+      enumtype::kRNN_TANH,
+      enumtype::kRNN_RELU> rnn_options_base_mode_t;
+
+  RNNOptionsBase(rnn_options_base_mode_t mode, int64_t input_size, int64_t hidden_size);
+
+  TORCH_ARG(rnn_options_base_mode_t, mode);
   /// The number of features of a single sample in the input sequence `x`.
   TORCH_ARG(int64_t, input_size);
   /// The number of features in the hidden state `h`.
   TORCH_ARG(int64_t, hidden_size);
   /// The number of recurrent layers (cells) to use.
-  TORCH_ARG(int64_t, layers) = 1;
+  TORCH_ARG(int64_t, num_layers) = 1;
   /// Whether a bias term should be added to all linear operations.
-  TORCH_ARG(bool, with_bias) = true;
+  TORCH_ARG(bool, bias) = true;
+  /// If true, the input sequence should be provided as `(batch, sequence,
+  /// features)`. If false (default), the expected layout is `(sequence, batch,
+  /// features)`.
+  TORCH_ARG(bool, batch_first) = false;
   /// If non-zero, adds dropout with the given probability to the output of each
   /// RNN layer, except the final layer.
   TORCH_ARG(double, dropout) = 0.0;
   /// Whether to make the RNN bidirectional.
   TORCH_ARG(bool, bidirectional) = false;
-  /// If true, the input sequence should be provided as `(batch, sequence,
-  /// features)`. If false (default), the expected layout is `(sequence, batch,
-  /// features)`.
-  TORCH_ARG(bool, batch_first) = false;
 };
 
 } // namespace detail
 
-enum class RNNActivation : uint32_t {ReLU, Tanh};
-
-/// Options for RNN modules.
+/// Options for the `RNN` module.
+///
+/// Example:
+/// ```
+/// RNN model(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
+/// ```
 struct TORCH_API RNNOptions {
+  typedef c10::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
+
   RNNOptions(int64_t input_size, int64_t hidden_size);
 
-  /// Sets the activation after linear operations to `tanh`.
-  RNNOptions& tanh();
-  /// Sets the activation after linear operations to `relu`.
-  RNNOptions& relu();
-
-  /// The number of features of a single sample in the input sequence `x`.
+  /// The number of expected features in the input `x`
   TORCH_ARG(int64_t, input_size);
-  /// The number of features in the hidden state `h`.
+  /// The number of features in the hidden state `h`
   TORCH_ARG(int64_t, hidden_size);
-  /// The number of recurrent layers (cells) to use.
-  TORCH_ARG(int64_t, layers) = 1;
-  /// Whether a bias term should be added to all linear operations.
-  TORCH_ARG(bool, with_bias) = true;
-  /// If non-zero, adds dropout with the given probability to the output of each
-  /// RNN layer, except the final layer.
-  TORCH_ARG(double, dropout) = 0.0;
-  /// Whether to make the RNN bidirectional.
-  TORCH_ARG(bool, bidirectional) = false;
-  /// If true, the input sequence should be provided as `(batch, sequence,
-  /// features)`. If false (default), the expected layout is `(sequence, batch,
-  /// features)`.
+  /// Number of recurrent layers. E.g., setting ``num_layers=2``
+  /// would mean stacking two RNNs together to form a `stacked RNN`,
+  /// with the second RNN taking in outputs of the first RNN and
+  /// computing the final results. Default: 1
+  TORCH_ARG(int64_t, num_layers) = 1;
+  /// The non-linearity to use. Can be either ``torch::kTanh`` or ``torch::kReLU``. Default: ``torch::kTanh``
+  TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh;
+  /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
+  /// Default: ``true``
+  TORCH_ARG(bool, bias) = true;
+  /// If ``true``, then the input and output tensors are provided
+  /// as `(batch, seq, feature)`. Default: ``false``
   TORCH_ARG(bool, batch_first) = false;
-  /// The activation to use after linear operations.
-  TORCH_ARG(RNNActivation, activation) = RNNActivation::ReLU;
+  /// If non-zero, introduces a `Dropout` layer on the outputs of each
+  /// RNN layer except the last layer, with dropout probability equal to
+  /// `dropout`. Default: 0
+  TORCH_ARG(double, dropout) = 0.0;
+  /// If ``true``, becomes a bidirectional RNN. Default: ``false``
+  TORCH_ARG(bool, bidirectional) = false;
 };
 
-using LSTMOptions = detail::RNNOptionsBase;
-using GRUOptions = detail::RNNOptionsBase;
+/// Options for the `LSTM` module.
+///
+/// Example:
+/// ```
+/// LSTM model(LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
+struct TORCH_API LSTMOptions {
+  LSTMOptions(int64_t input_size, int64_t hidden_size);
+
+  /// The number of expected features in the input `x`
+  TORCH_ARG(int64_t, input_size);
+  /// The number of features in the hidden state `h`
+  TORCH_ARG(int64_t, hidden_size);
+  /// Number of recurrent layers. E.g., setting ``num_layers=2``
+  /// would mean stacking two LSTMs together to form a `stacked LSTM`,
+  /// with the second LSTM taking in outputs of the first LSTM and
+  /// computing the final results. Default: 1
+  TORCH_ARG(int64_t, num_layers) = 1;
+  /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
+  /// Default: ``true``
+  TORCH_ARG(bool, bias) = true;
+  /// If ``true``, then the input and output tensors are provided
+  /// as (batch, seq, feature). Default: ``false``
+  TORCH_ARG(bool, batch_first) = false;
+  /// If non-zero, introduces a `Dropout` layer on the outputs of each
+  /// LSTM layer except the last layer, with dropout probability equal to
+  /// `dropout`. Default: 0
+  TORCH_ARG(double, dropout) = 0.0;
+  /// If ``true``, becomes a bidirectional LSTM. Default: ``false``
+  TORCH_ARG(bool, bidirectional) = false;
+};
+
+/// Options for the `GRU` module.
+///
+/// Example:
+/// ```
+/// GRU model(GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
+struct TORCH_API GRUOptions {
+  GRUOptions(int64_t input_size, int64_t hidden_size);
+
+  /// The number of expected features in the input `x`
+  TORCH_ARG(int64_t, input_size);
+  /// The number of features in the hidden state `h`
+  TORCH_ARG(int64_t, hidden_size);
+  /// Number of recurrent layers. E.g., setting ``num_layers=2``
+  /// would mean stacking two GRUs together to form a `stacked GRU`,
+  /// with the second GRU taking in outputs of the first GRU and
+  /// computing the final results. Default: 1
+  TORCH_ARG(int64_t, num_layers) = 1;
+  /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
+  /// Default: ``true``
+  TORCH_ARG(bool, bias) = true;
+  /// If ``true``, then the input and output tensors are provided
+  /// as (batch, seq, feature). Default: ``false``
+  TORCH_ARG(bool, batch_first) = false;
+  /// If non-zero, introduces a `Dropout` layer on the outputs of each
+  /// GRU layer except the last layer, with dropout probability equal to
+  /// `dropout`. Default: 0
+  TORCH_ARG(double, dropout) = 0.0;
+  /// If ``true``, becomes a bidirectional GRU. Default: ``false``
+  TORCH_ARG(bool, bidirectional) = false;
+};
 
 namespace detail {
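Taken together, a minimal before/after sketch of the options migration described in the commit summary:

```cpp
#include <torch/torch.h>

int main() {
  using namespace torch::nn;

  // Before this PR (no longer compiles):
  //   RNN old_rnn(RNNOptions(128, 64).relu().layers(2).with_bias(false));

  // After this PR:
  RNN rnn(RNNOptions(128, 64).nonlinearity(torch::kReLU).num_layers(2).bias(false));
  LSTM lstm(LSTMOptions(128, 64).num_layers(2).bias(false));
  GRU gru(GRUOptions(128, 64).num_layers(2).bias(false));
  return 0;
}
```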
@@ -30,3 +30,7 @@ TORCH_ENUM_DEFINE(BatchMean)
 TORCH_ENUM_DEFINE(Zeros)
 TORCH_ENUM_DEFINE(Border)
 TORCH_ENUM_DEFINE(Reflection)
+TORCH_ENUM_DEFINE(RNN_TANH)
+TORCH_ENUM_DEFINE(RNN_RELU)
+TORCH_ENUM_DEFINE(LSTM)
+TORCH_ENUM_DEFINE(GRU)
@@ -1,6 +1,5 @@
 #include <torch/nn/modules/rnn.h>
 
-#include <torch/nn/modules/dropout.h>
 #include <torch/nn/init.h>
 #include <torch/types.h>
 #include <torch/utils.h>
@ -19,65 +18,193 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
using namespace torch::nn::utils::rnn;
|
||||
|
||||
namespace torch {
|
||||
namespace nn {
|
||||
|
||||
/// These must line up with the CUDNN mode codes:
|
||||
/// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
|
||||
enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
|
||||
|
||||
CuDNNMode get_cudnn_mode_for_rnn(detail::RNNOptionsBase::rnn_options_base_mode_t mode) {
|
||||
if (c10::get_if<enumtype::kRNN_RELU>(&mode)) {
|
||||
return CuDNNMode::RNN_RELU;
|
||||
} else if (c10::get_if<enumtype::kRNN_TANH>(&mode)) {
|
||||
return CuDNNMode::RNN_TANH;
|
||||
} else if (c10::get_if<enumtype::kLSTM>(&mode)) {
|
||||
return CuDNNMode::LSTM;
|
||||
} else if (c10::get_if<enumtype::kGRU>(&mode)) {
|
||||
return CuDNNMode::GRU;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(mode));
|
||||
}
|
||||
}
|
||||
|
||||
Tensor apply_permutation(const Tensor& tensor, const Tensor& permutation, int64_t dim = 1) {
|
||||
return tensor.index_select(dim, permutation);
|
||||
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
namespace detail {
template <typename Derived>
RNNImplBase<Derived>::RNNImplBase(
const RNNOptionsBase& options_,
optional<CuDNNMode> cudnn_mode,
int64_t number_of_gates)
: options(options_),
number_of_gates_(number_of_gates),
cudnn_mode_(std::move(cudnn_mode)) {
RNNImplBase<Derived>::RNNImplBase(const RNNOptionsBase& options_)
: options_base(options_) {
reset();
}

template <typename Derived>
void RNNImplBase<Derived>::reset() {
const auto num_directions = options.bidirectional() ? 2 : 1;
const int64_t num_directions = options_base.bidirectional() ? 2 : 1;

w_ih.resize(options.layers() * num_directions);
w_hh.resize(options.layers() * num_directions);
b_ih.resize(options.layers() * num_directions);
b_hh.resize(options.layers() * num_directions);
TORCH_CHECK(
0 <= options_base.dropout() && options_base.dropout() <= 1,
"dropout should be a number in range [0, 1] ",
"representing the probability of an element being ",
"zeroed");

const int64_t gate_size = options.hidden_size() * number_of_gates_;
if (options_base.dropout() > 0 && options_base.num_layers() == 1) {
TORCH_WARN(
"dropout option adds dropout after all but last ",
"recurrent layer, so non-zero dropout expects ",
"num_layers greater than 1, but got dropout=", options_base.dropout(), " and ",
"num_layers=", options_base.num_layers());
}

for (int64_t layer = 0; layer < options.layers(); ++layer) {
for (auto direction = 0; direction < num_directions; direction++) {
const auto layer_input_size = layer == 0 ? options.input_size() :
options.hidden_size() * num_directions;
const auto suffix = direction == 1 ? "_reverse" : "";
const auto layer_idx = (layer * num_directions) + direction;
w_ih[layer_idx] = this->register_parameter(
"weight_ih_l" + std::to_string(layer) + suffix,
torch::empty({gate_size, layer_input_size}));
w_hh[layer_idx] = this->register_parameter(
"weight_hh_l" + std::to_string(layer) + suffix,
torch::empty({gate_size, options.hidden_size()}));
int64_t gate_size = 0;
if (c10::get_if<enumtype::kLSTM>(&options_base.mode())) {
gate_size = 4 * options_base.hidden_size();
} else if (c10::get_if<enumtype::kGRU>(&options_base.mode())) {
gate_size = 3 * options_base.hidden_size();
} else if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
gate_size = options_base.hidden_size();
} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
gate_size = options_base.hidden_size();
} else {
TORCH_CHECK(false, "Unrecognized RNN mode: " + torch::enumtype::get_enum_name(options_base.mode()));
}

if (options.with_bias()) {
b_ih[layer_idx] = this->register_parameter(
"bias_ih_l" + std::to_string(layer) + suffix,
torch::empty({gate_size}));
b_hh[layer_idx] = this->register_parameter(
"bias_hh_l" + std::to_string(layer) + suffix,
torch::empty({gate_size}));
flat_weights_names_ = {};
all_weights_ = {};

for (int64_t layer = 0; layer < options_base.num_layers(); layer++) {
for (int64_t direction = 0; direction < num_directions; direction++) {
int64_t layer_input_size = layer == 0 ? options_base.input_size() : options_base.hidden_size() * num_directions;

auto w_ih = torch::empty({gate_size, layer_input_size});
auto w_hh = torch::empty({gate_size, options_base.hidden_size()});
auto b_ih = torch::empty({gate_size});
// Second bias vector included for CuDNN compatibility. Only one
// bias vector is needed in standard definition.
auto b_hh = torch::empty({gate_size});
std::vector<Tensor> layer_params = {w_ih, w_hh, b_ih, b_hh};

std::string suffix = direction == 1 ? "_reverse" : "";
std::vector<std::string> param_names = {"weight_ih_l{layer}{suffix}", "weight_hh_l{layer}{suffix}"};
if (options_base.bias()) {
param_names.emplace_back("bias_ih_l{layer}{suffix}");
param_names.emplace_back("bias_hh_l{layer}{suffix}");
}
for (size_t i = 0; i < param_names.size(); i++) { // NOLINT(modernize-loop-convert)
std::string x = std::regex_replace(param_names[i], std::regex("\\{layer\\}"), c10::str(layer));
x = std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix));
param_names[i] = x;
}

for (size_t i = 0; i < param_names.size(); i++) {
auto name = param_names[i];
auto param = layer_params[i];
this->register_parameter(name, param);
}
flat_weights_names_.insert(flat_weights_names_.end(), param_names.begin(), param_names.end());
all_weights_.emplace_back(param_names);
}
}

flat_weights_ = {};
for (const auto& wn : flat_weights_names_) {
auto named_parameters = this->named_parameters(/*recurse=*/false);
if (named_parameters.contains(wn)) {
flat_weights_.emplace_back(named_parameters[wn]);
} else {
flat_weights_.emplace_back(Tensor());
}
}

this->flatten_parameters();
this->reset_parameters();
}

template <typename Derived>
void RNNImplBase<Derived>::flatten_parameters() {
// Resets parameter data pointer so that they can use faster code paths.
//
// Right now, this works only if the module is on the GPU and cuDNN is enabled.
// Otherwise, it's a no-op.

// Short-circuits if flat_weights_ is only partially instantiated
if (flat_weights_.size() != flat_weights_names_.size()) {
return;
}

// Short-circuits if any tensor in self.flat_weights_ is not acceptable to cuDNN
// or the tensors in flat_weights_ are of different dtypes

auto first_fw = flat_weights_[0];
auto dtype = first_fw.dtype();
for (const auto& fw : flat_weights_) {
if (!(fw.dtype() == dtype) ||
!fw.is_cuda() ||
!torch::cudnn_is_acceptable(fw)) {
return;
}
}

// If any parameters alias, we fall back to the slower, copying code path. This is
// a sufficient check, because overlapping parameter buffers that don't completely
// alias would break the assumptions of the uniqueness check in
// Module::named_parameters().
std::unordered_set<void*> unique_data_ptrs;
for (const auto& p : flat_weights_) {
unique_data_ptrs.emplace(p.data_ptr());
}
if (unique_data_ptrs.size() != flat_weights_.size()) {
return;
}

{
NoGradGuard no_grad;
const auto stdv = 1.0 / std::sqrt(options.hidden_size());
for (auto& p : this->parameters()) {
p.uniform_(-stdv, stdv);
torch::DeviceGuard device_guard(first_fw.device());

// Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
// an inplace operation on self.flat_weights_
{
torch::NoGradGuard no_grad;
if (torch::_use_cudnn_rnn_flatten_weight()) {
torch::_cudnn_rnn_flatten_weight(
flat_weights_,
options_base.bias() ? 4 : 2,
options_base.input_size(),
static_cast<int64_t>(get_cudnn_mode_for_rnn(options_base.mode())),
options_base.hidden_size(),
options_base.num_layers(),
options_base.batch_first(),
options_base.bidirectional());
}
}
}
}

flatten_parameters();

template <typename Derived>
void RNNImplBase<Derived>::reset_flat_weights() {
flat_weights_ = {};
for (const auto& wn : flat_weights_names_) {
auto named_parameters = this->named_parameters(/*recurse=*/false);
if (named_parameters.contains(wn)) {
flat_weights_.emplace_back(named_parameters[wn]);
} else {
flat_weights_.emplace_back(Tensor());
}
}
}

template <typename Derived>
@ -86,126 +213,113 @@ void RNNImplBase<Derived>::to(
torch::Dtype dtype,
bool non_blocking) {
nn::Module::to(device, dtype, non_blocking);
reset_flat_weights();
flatten_parameters();
}

template <typename Derived>
void RNNImplBase<Derived>::to(torch::Dtype dtype, bool non_blocking) {
nn::Module::to(dtype, non_blocking);
reset_flat_weights();
flatten_parameters();
}

template <typename Derived>
void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
nn::Module::to(device, non_blocking);
const auto num_directions = options.bidirectional() ? 2 : 1;
for (int64_t layer = 0; layer < options.layers(); layer++) {
for (auto direction = 0; direction < num_directions; direction++) {
const auto layer_idx = (layer * num_directions) + direction;
w_ih[layer_idx] = w_ih[layer_idx].to(device, non_blocking);
w_hh[layer_idx] = w_hh[layer_idx].to(device, non_blocking);
if (options.with_bias()) {
b_ih[layer_idx] = b_ih[layer_idx].to(device, non_blocking);
b_hh[layer_idx] = b_hh[layer_idx].to(device, non_blocking);
}
}
}
reset_flat_weights();
flatten_parameters();
}

template <typename Derived>
void RNNImplBase<Derived>::reset_parameters() {
const double stdv = 1.0 / std::sqrt(options_base.hidden_size());
for (auto& weight : this->parameters()) {
init::uniform_(weight, -stdv, stdv);
}
}

template <typename Derived>
void RNNImplBase<Derived>::check_input(const Tensor& input, const Tensor& batch_sizes) const {
int64_t expected_input_dim = batch_sizes.defined() ? 2 : 3;
TORCH_CHECK(
input.dim() == expected_input_dim,
"input must have ", expected_input_dim, " dimensions, got ", input.dim());
TORCH_CHECK(
options_base.input_size() == input.size(-1),
"input.size(-1) must be equal to input_size. Expected ", options_base.input_size(), ", got ", input.size(-1));
}

template <typename Derived>
std::tuple<int64_t, int64_t, int64_t> RNNImplBase<Derived>::get_expected_hidden_size(
const Tensor& input, const Tensor& batch_sizes) const {
int64_t mini_batch = 0;
if (batch_sizes.defined()) {
mini_batch = batch_sizes[0].item<int64_t>();
} else {
mini_batch = options_base.batch_first() ? input.size(0) : input.size(1);
}
int64_t num_directions = options_base.bidirectional() ? 2 : 1;
return std::make_tuple(options_base.num_layers() * num_directions, mini_batch, options_base.hidden_size());
}

template <typename Derived>
void RNNImplBase<Derived>::check_hidden_size(
const Tensor& hx,
std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
std::string msg) const {
auto expected_hidden_size_vec = std::vector<int64_t>({
std::get<0>(expected_hidden_size),
std::get<1>(expected_hidden_size),
std::get<2>(expected_hidden_size),
});
if (hx.sizes() != expected_hidden_size_vec) {
msg = std::regex_replace(msg, std::regex("\\{1\\}"), c10::str(expected_hidden_size_vec));
msg = std::regex_replace(msg, std::regex("\\{2\\}"), c10::str(hx.sizes()));
TORCH_CHECK(false, msg);
}
}

template <typename Derived>
void RNNImplBase<Derived>::check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const {
this->check_input(input, batch_sizes);
auto expected_hidden_size = this->get_expected_hidden_size(input, batch_sizes);

this->check_hidden_size(hidden, expected_hidden_size);
}

template <typename Derived>
Tensor RNNImplBase<Derived>::permute_hidden(Tensor hx, const Tensor& permutation) const {
if (!permutation.defined()) {
return hx;
}
return apply_permutation(hx, permutation);
}

template <typename Derived>
void RNNImplBase<Derived>::pretty_print(std::ostream& stream) const {
const std::string name = this->name();
const std::string name_without_impl = name.substr(0, name.size() - 4);
stream << name_without_impl << "(input_size=" << options.input_size()
<< ", hidden_size=" << options.hidden_size()
<< ", layers=" << options.layers() << ", dropout=" << options.dropout()
stream << std::boolalpha << name_without_impl << "(input_size=" << options_base.input_size()
<< ", hidden_size=" << options_base.hidden_size()
<< ", num_layers=" << options_base.num_layers()
<< ", bias=" << options_base.bias()
<< ", batch_first=" << options_base.batch_first()
<< ", dropout=" << options_base.dropout()
<< ", bidirectional=" << options_base.bidirectional()
<< ")";
}

template <typename Derived>
void RNNImplBase<Derived>::flatten_parameters() {
// Cache the flattened weight and bias vector.
flat_weights_ = flat_weights();

if (!cudnn_mode_ || !torch::cudnn_is_acceptable(w_ih.at(0))) {
return;
}

NoGradGuard no_grad;
if (torch::_use_cudnn_rnn_flatten_weight()) {
torch::_cudnn_rnn_flatten_weight(
flat_weights_,
/*weight_stride0=*/options.with_bias() ? 4 : 2,
options.input_size(),
static_cast<int64_t>(*cudnn_mode_),
options.hidden_size(),
options.layers(),
/*batch_first=*/options.batch_first(),
/*bidirectional=*/options.bidirectional());
}
}

template <typename Derived>
RNNOutput RNNImplBase<Derived>::generic_forward(
std::function<RNNFunctionSignature> function,
const Tensor& input,
Tensor state) {
if (!state.defined()) {
// #layers, batch size, state size
const auto batch_size = input.size(options.batch_first() ? 0 : 1);
const auto num_directions = options.bidirectional() ? 2 : 1;
state = torch::zeros(
{options.layers() * num_directions, batch_size, options.hidden_size()},
input.options());
}
Tensor output, new_state;
std::tie(output, new_state) = function(
input,
std::move(state),
flat_weights_,
options.with_bias(),
options.layers(),
options.dropout(),
this->is_training(),
options.bidirectional(),
options.batch_first());
return {output, new_state};
}

template <typename Derived>
std::vector<Tensor> RNNImplBase<Derived>::flat_weights() const {
// Organize all weights in a flat vector in the order
// (w_ih, w_hh, b_ih, b_hh), repeated for each layer (next to each other).
std::vector<Tensor> flat;
const auto num_directions = options.bidirectional() ? 2 : 1;
for (int64_t layer = 0; layer < options.layers(); layer++) {
for (auto direction = 0; direction < num_directions; direction++) {
const auto layer_idx = (layer * num_directions) + direction;
flat.push_back(w_ih[layer_idx]);
flat.push_back(w_hh[layer_idx]);
if (options.with_bias()) {
flat.push_back(b_ih[layer_idx]);
flat.push_back(b_hh[layer_idx]);
}
std::vector<Tensor> RNNImplBase<Derived>::all_weights() const {
std::vector<Tensor> result = {};
auto named_parameters = this->named_parameters(/*recurse=*/false);
for (const auto& weights : all_weights_) {
for (const auto& weight : weights) {
result.emplace_back(named_parameters[weight]);
}
}
return flat;
}

template <typename Derived>
bool RNNImplBase<Derived>::any_parameters_alias() const {
// If any parameters alias, we fall back to the slower, copying code path.
// This is a sufficient check, because overlapping parameter buffers that
// don't completely alias would break the assumptions of the uniqueness check
// in Module.named_parameters().
std::unordered_set<void*> unique_data_ptrs;
auto params = this->parameters();
unique_data_ptrs.reserve(params.size());
for (const auto& p : params) {
unique_data_ptrs.emplace(p.data_ptr());
}
return unique_data_ptrs.size() != params.size();
return result;
}

template class RNNImplBase<LSTMImpl>;
@ -215,91 +329,275 @@ template class RNNImplBase<RNNImpl>;
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

RNNImpl::RNNImpl(const RNNOptions& options_)
: detail::RNNImplBase<RNNImpl>(
detail::RNNOptionsBase(options_.input_size(), options_.hidden_size())
.layers(options_.layers())
.with_bias(options_.with_bias())
.dropout(options_.dropout())
.bidirectional(options_.bidirectional())
.batch_first(options_.batch_first()),
static_cast<CuDNNMode>(options_.activation())),
options(options_) {}

void RNNImpl::pretty_print(std::ostream& stream) const {
stream << "torch::nn::RNN(input_size=" << options.input_size()
<< ", hidden_size=" << options.hidden_size()
<< ", layers=" << options.layers() << ", dropout=" << options.dropout()
<< ", activation="
<< (options.activation() == RNNActivation::Tanh ? "tanh" : "relu")
<< ")";
detail::RNNOptionsBase::rnn_options_base_mode_t compute_rnn_options_base_mode(
RNNOptions::nonlinearity_t nonlinearity) {
if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
return torch::kRNN_TANH;
} else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {
return torch::kRNN_RELU;
} else {
TORCH_CHECK(false, "Unknown nonlinearity ", torch::enumtype::get_enum_name(nonlinearity));
}
}
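The old tanh()/relu() setters are replaced by a nonlinearity option that this helper maps onto the base mode; for example (an illustrative construction, not part of this diff):

// torch::kTanh (the default) maps to kRNN_TANH; torch::kReLU maps to kRNN_RELU.
torch::nn::RNN relu_rnn(torch::nn::RNNOptions(10, 20).num_layers(2).nonlinearity(torch::kReLU));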

RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) {
switch (options.activation()) {
case RNNActivation::ReLU:
return generic_forward(
static_cast<RNNFunctionSignature*>(&torch::rnn_relu),
input,
std::move(state));
case RNNActivation::Tanh:
return generic_forward(
static_cast<RNNFunctionSignature*>(&torch::rnn_tanh),
input,
std::move(state));
default:
AT_ERROR("Unhandled RNN activation function!");
RNNImpl::RNNImpl(const RNNOptions& options_)
: detail::RNNImplBase<RNNImpl>(
detail::RNNOptionsBase(
compute_rnn_options_base_mode(options_.nonlinearity()),
options_.input_size(),
options_.hidden_size())
.num_layers(options_.num_layers())
.bias(options_.bias())
.batch_first(options_.batch_first())
.dropout(options_.dropout())
.bidirectional(options_.bidirectional())),
options(options_) {}

std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
const Tensor& input,
const Tensor& batch_sizes,
const Tensor& sorted_indices,
int64_t max_batch_size,
Tensor hx) {
if (!hx.defined()) {
int64_t num_directions = options_base.bidirectional() ? 2 : 1;
hx = torch::zeros({options_base.num_layers() * num_directions,
max_batch_size, options_base.hidden_size()},
torch::dtype(input.dtype()).device(input.device()));
} else {
// Each batch of the hidden state should match the input sequence that
// the user believes he/she is passing in.
hx = this->permute_hidden(hx, sorted_indices);
}

this->check_forward_args(input, hx, batch_sizes);

std::tuple<Tensor, Tensor> result;
if (!batch_sizes.defined()) {
if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
result = torch::rnn_tanh(input, hx, flat_weights_, options_base.bias(), options_base.num_layers(),
options_base.dropout(), this->is_training(), options_base.bidirectional(), options_base.batch_first());
} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
result = torch::rnn_relu(input, hx, flat_weights_, options_base.bias(), options_base.num_layers(),
options_base.dropout(), this->is_training(), options_base.bidirectional(), options_base.batch_first());
} else {
TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(options_base.mode()));
}
} else {
if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
result = torch::rnn_tanh(input, batch_sizes, hx, flat_weights_, options_base.bias(),
options_base.num_layers(), options_base.dropout(), this->is_training(), options_base.bidirectional());
} else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
result = torch::rnn_relu(input, batch_sizes, hx, flat_weights_, options_base.bias(),
options_base.num_layers(), options_base.dropout(), this->is_training(), options_base.bidirectional());
} else {
TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(options_base.mode()));
}
}
auto output = std::get<0>(result);
auto hidden = std::get<1>(result);

return std::make_tuple(output, hidden);
}

std::tuple<Tensor, Tensor> RNNImpl::forward(const Tensor& input, Tensor hx) {
auto batch_sizes = torch::Tensor();
auto max_batch_size = options_base.batch_first() ? input.size(0) : input.size(1);
auto sorted_indices = torch::Tensor();
auto unsorted_indices = torch::Tensor();

Tensor output, hidden;
std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));

return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
}

std::tuple<PackedSequence, Tensor> RNNImpl::forward_with_packed_input(const PackedSequence& packed_input, Tensor hx) {
const auto& input = packed_input.data();
const auto& batch_sizes = packed_input.batch_sizes();
const auto& sorted_indices = packed_input.sorted_indices();
const auto& unsorted_indices = packed_input.unsorted_indices();
auto max_batch_size = batch_sizes[0].item<int64_t>();

Tensor output, hidden;
std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));

auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
}
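A sketch of the new packed-sequence path (this assumes torch::nn::utils::rnn::pack_sequence is available in the C++ frontend; the module reference and tensors are illustrative, not part of this diff):

namespace rnn_utils = torch::nn::utils::rnn;

void rnn_packed_example(torch::nn::RNN& rnn) {
  // Two sequences of lengths 3 and 2, each step with 10 input features.
  auto packed = rnn_utils::pack_sequence({torch::randn({3, 10}), torch::randn({2, 10})});
  auto result = rnn->forward_with_packed_input(packed);
  const auto& packed_output = std::get<0>(result);  // PackedSequence
  const auto& h_n = std::get<1>(result);            // final hidden state
}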

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

LSTMImpl::LSTMImpl(const LSTMOptions& options_)
: detail::RNNImplBase<LSTMImpl>(
options_,
CuDNNMode::LSTM,
/*number_of_gates=*/4) {}
detail::RNNOptionsBase(
torch::kLSTM,
options_.input_size(),
options_.hidden_size())
.num_layers(options_.num_layers())
.bias(options_.bias())
.batch_first(options_.batch_first())
.dropout(options_.dropout())
.bidirectional(options_.bidirectional())),
options(options_) {}

RNNOutput LSTMImpl::forward(const Tensor& input, Tensor state) {
// It would be trickier to adapt the `generic_forward` for the LSTM because
// its output has a different dimensionality (3-tuple vs. 2-tuple), while we
// always return one state variable (stacking the hidden/cell state into one),
// which also makes the state variables going into the `generic_forward`, and
// the way we default-initialize the state when it is not passed, slightly
// different. So we just re-implement it specifically for the LSTM here.
if (!state.defined()) {
// 2 for hidden state and cell state, then #layers, batch size, state size
const auto batch_size = input.size(options.batch_first() ? 0 : 1);
const auto num_directions = options.bidirectional() ? 2 : 1;
state = torch::zeros(
{2, options.layers() * num_directions, batch_size, options.hidden_size()},
input.options());
void LSTMImpl::check_forward_args(const Tensor& input, std::tuple<Tensor, Tensor> hidden, const Tensor& batch_sizes) const {
this->check_input(input, batch_sizes);
auto expected_hidden_size = this->get_expected_hidden_size(input, batch_sizes);

this->check_hidden_size(std::get<0>(hidden), expected_hidden_size,
"Expected hidden[0] size {1}, got {2}");
this->check_hidden_size(std::get<1>(hidden), expected_hidden_size,
"Expected hidden[1] size {1}, got {2}");
}

std::tuple<Tensor, Tensor> LSTMImpl::permute_hidden(std::tuple<Tensor, Tensor> hx, const Tensor& permutation) const {
if (!permutation.defined()) {
return hx;
}
Tensor output, hidden_state, cell_state;
std::tie(output, hidden_state, cell_state) = torch::lstm(
input,
{state[0], state[1]},
flat_weights_,
options.with_bias(),
options.layers(),
options.dropout(),
this->is_training(),
options.bidirectional(),
options.batch_first());
return {output, torch::stack({hidden_state, cell_state})};
return std::make_tuple(
apply_permutation(std::get<0>(hx), permutation),
apply_permutation(std::get<1>(hx), permutation)
);
}

std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward_helper(
const Tensor& input,
const Tensor& batch_sizes,
const Tensor& sorted_indices,
int64_t max_batch_size,
torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {

std::tuple<Tensor, Tensor> hx;
if (!hx_opt.has_value()) {
int64_t num_directions = options.bidirectional() ? 2 : 1;
auto zeros = torch::zeros({options.num_layers() * num_directions,
max_batch_size, options.hidden_size()},
torch::dtype(input.dtype()).device(input.device()));
hx = std::make_tuple(zeros, zeros);
} else {
hx = hx_opt.value();
// Each batch of the hidden state should match the input sequence that
// the user believes he/she is passing in.
hx = this->permute_hidden(hx, sorted_indices);
}

this->check_forward_args(input, hx, batch_sizes);
std::tuple<Tensor, Tensor, Tensor> result;
if (!batch_sizes.defined()) {
result = torch::lstm(input, {std::get<0>(hx), std::get<1>(hx)}, flat_weights_, options.bias(), options.num_layers(),
options.dropout(), this->is_training(), options.bidirectional(), options.batch_first());
} else {
result = torch::lstm(input, batch_sizes, {std::get<0>(hx), std::get<1>(hx)}, flat_weights_, options.bias(),
options.num_layers(), options.dropout(), this->is_training(), options.bidirectional());
}
auto output = std::get<0>(result);
auto hidden = std::make_tuple(std::get<1>(result), std::get<2>(result));

return std::make_tuple(output, hidden);
}

std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward(
const Tensor& input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
auto batch_sizes = torch::Tensor();
auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
auto sorted_indices = torch::Tensor();
auto unsorted_indices = torch::Tensor();

Tensor output;
std::tuple<Tensor, Tensor> hidden;
std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));

return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
}
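LSTMImpl::forward above returns the output together with a (hidden, cell) tuple rather than a stacked state tensor; a minimal sketch of calling it through the module holder (illustrative module and shapes, not part of this diff):

torch::nn::LSTM lstm(torch::nn::LSTMOptions(10, 20).num_layers(2));
auto x = torch::randn({5, 3, 10});  // (seq_len, batch, input_size)
torch::Tensor output;
std::tuple<torch::Tensor, torch::Tensor> state;
std::tie(output, state) = lstm->forward(x);
const auto& h_n = std::get<0>(state);  // (num_layers * num_directions, batch, hidden_size) = {2, 3, 20}
const auto& c_n = std::get<1>(state);  // same shape as h_n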

std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> LSTMImpl::forward_with_packed_input(
const PackedSequence& packed_input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
const auto& input = packed_input.data();
const auto& batch_sizes = packed_input.batch_sizes();
const auto& sorted_indices = packed_input.sorted_indices();
const auto& unsorted_indices = packed_input.unsorted_indices();
auto max_batch_size = batch_sizes[0].item<int64_t>();

Tensor output;
std::tuple<Tensor, Tensor> hidden;
std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));

auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

GRUImpl::GRUImpl(const GRUOptions& options_)
: detail::RNNImplBase<GRUImpl>(
options_,
CuDNNMode::GRU,
/*number_of_gates=*/3) {}
detail::RNNOptionsBase(
torch::kGRU,
options_.input_size(),
options_.hidden_size())
.num_layers(options_.num_layers())
.bias(options_.bias())
.batch_first(options_.batch_first())
.dropout(options_.dropout())
.bidirectional(options_.bidirectional())),
options(options_) {}

RNNOutput GRUImpl::forward(const Tensor& input, Tensor state) {
return generic_forward(
static_cast<RNNFunctionSignature*>(&torch::gru), input, std::move(state));
std::tuple<Tensor, Tensor> GRUImpl::forward_helper(
const Tensor& input,
const Tensor& batch_sizes,
const Tensor& sorted_indices,
int64_t max_batch_size,
Tensor hx) {
if (!hx.defined()) {
int64_t num_directions = options.bidirectional() ? 2 : 1;
hx = torch::zeros({options.num_layers() * num_directions,
max_batch_size, options.hidden_size()},
torch::dtype(input.dtype()).device(input.device()));
} else {
// Each batch of the hidden state should match the input sequence that
// the user believes he/she is passing in.
hx = this->permute_hidden(hx, sorted_indices);
}

this->check_forward_args(input, hx, batch_sizes);
std::tuple<Tensor, Tensor> result;
if (!batch_sizes.defined()) {
result = torch::gru(input, hx, flat_weights_, options.bias(), options.num_layers(),
options.dropout(), this->is_training(), options.bidirectional(), options.batch_first());
} else {
result = torch::gru(input, batch_sizes, hx, flat_weights_, options.bias(),
options.num_layers(), options.dropout(), this->is_training(), options.bidirectional());
}
auto output = std::get<0>(result);
auto hidden = std::get<1>(result);

return std::make_tuple(output, hidden);
}

std::tuple<Tensor, Tensor> GRUImpl::forward(const Tensor& input, Tensor hx) {
auto batch_sizes = torch::Tensor();
auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
auto sorted_indices = torch::Tensor();
auto unsorted_indices = torch::Tensor();

Tensor output, hidden;
std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));

return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
}

std::tuple<PackedSequence, Tensor> GRUImpl::forward_with_packed_input(const PackedSequence& packed_input, Tensor hx) {
const auto& input = packed_input.data();
const auto& batch_sizes = packed_input.batch_sizes();
const auto& sorted_indices = packed_input.sorted_indices();
const auto& unsorted_indices = packed_input.unsorted_indices();
auto max_batch_size = batch_sizes[0].item<int64_t>();

Tensor output, hidden;
std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));

auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -5,21 +5,19 @@ namespace nn {

namespace detail {

RNNOptionsBase::RNNOptionsBase(int64_t input_size, int64_t hidden_size)
: input_size_(input_size), hidden_size_(hidden_size) {}
RNNOptionsBase::RNNOptionsBase(rnn_options_base_mode_t mode, int64_t input_size, int64_t hidden_size)
: mode_(mode), input_size_(input_size), hidden_size_(hidden_size) {}

} // namespace detail

RNNOptions::RNNOptions(int64_t input_size, int64_t hidden_size)
: input_size_(input_size), hidden_size_(hidden_size) {}

RNNOptions& RNNOptions::tanh() {
return activation(RNNActivation::Tanh);
}
LSTMOptions::LSTMOptions(int64_t input_size, int64_t hidden_size)
: input_size_(input_size), hidden_size_(hidden_size) {}

RNNOptions& RNNOptions::relu() {
return activation(RNNActivation::ReLU);
}
GRUOptions::GRUOptions(int64_t input_size, int64_t hidden_size)
: input_size_(input_size), hidden_size_(hidden_size) {}

namespace detail {