[C++ API] Allow skipping default arguments in module's forward method when module is used in Sequential (#33027)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33027

This PR allows default arguments in a module's forward method to be skipped when the module is used in `torch::nn::Sequential`. It does so by introducing the `FORWARD_HAS_DEFAULT_ARGS` macro and requiring that every module with default arguments in its forward method include a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro call.
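
For illustration, a module author opts in by listing each defaulted argument's position and default value with the macro. A minimal sketch (the `M` module, its `scale` argument, and the chosen default are illustrative, not part of this PR; the macro takes `{argument index, default wrapped in torch::nn::AnyValue}` pairs):

#include <torch/torch.h>

struct MImpl : torch::nn::Module {
  torch::Tensor forward(torch::Tensor x, int scale = 2) {
    return x * scale;
  }

 protected:
  // Argument index 1 (`scale`) defaults to 2; wrapping the default in
  // torch::nn::AnyValue lets Sequential supply it when a caller omits it.
  FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)})
};
TORCH_MODULE(M);

// With the macro in place, both calls are valid:
//   torch::nn::Sequential seq(M());
//   seq->forward(torch::ones({2, 2}));     // scale falls back to 2
//   seq->forward(torch::ones({2, 2}), 3);  // scale is given explicitly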

Fixes issue mentioned in https://github.com/pytorch/pytorch/issues/30931#issuecomment-564144468.

Test Plan: Imported from OSS

Differential Revision: D19777815

Pulled By: yf225

fbshipit-source-id: 73282fcf63377530063e0092a9d84b6c139d2e32
Author: Will Feng
Date: 2020-02-17 20:33:51 -08:00
Committed by: Facebook GitHub Bot
Parent: 4724964810
Commit: a203dc2e6d

11 changed files with 470 additions and 18 deletions

@@ -436,3 +436,208 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
" (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");
}
TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
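  // ConvTranspose1d: forward()'s optional output_size argument is omitted.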
  {
    Sequential sequential(
        Identity(),
        ConvTranspose1d(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose1dImpl>(sequential[1])
        ->weight.set_data(torch::arange(18.).reshape({3, 2, 3}));
    auto x = torch::arange(30.).reshape({2, 3, 5});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{ 150., 333., 552., 615., 678., 501., 276.},
          { 195., 432., 714., 804., 894., 654., 357.}},
         {{ 420., 918., 1497., 1560., 1623., 1176., 636.},
          { 600., 1287., 2064., 2154., 2244., 1599., 852.}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
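  // ConvTranspose2d: forward()'s optional output_size argument is omitted.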
  {
    Sequential sequential(
        Identity(),
        ConvTranspose2d(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose2dImpl>(sequential[1])
        ->weight.set_data(torch::arange(54.).reshape({3, 2, 3, 3}));
    auto x = torch::arange(75.).reshape({1, 3, 5, 5});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{{ 2250., 4629., 7140., 7311., 7482., 5133., 2640.},
           { 4995., 10272., 15837., 16206., 16575., 11364., 5841.},
           { 8280., 17019., 26226., 26820., 27414., 18783., 9648.},
           { 9225., 18954., 29196., 29790., 30384., 20808., 10683.},
           {10170., 20889., 32166., 32760., 33354., 22833., 11718.},
           { 7515., 15420., 23721., 24144., 24567., 16800., 8613.},
           { 4140., 8487., 13044., 13269., 13494., 9219., 4722.}},
          {{ 2925., 6006., 9246., 9498., 9750., 6672., 3423.},
           { 6480., 13296., 20454., 20985., 21516., 14712., 7542.},
           {10710., 21960., 33759., 34596., 35433., 24210., 12402.},
           {12060., 24705., 37944., 38781., 39618., 27045., 13842.},
           {13410., 27450., 42129., 42966., 43803., 29880., 15282.},
           { 9810., 20064., 30768., 31353., 31938., 21768., 11124.},
           { 5355., 10944., 16770., 17076., 17382., 11838., 6045.}}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
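  // ConvTranspose3d: forward()'s optional output_size argument is omitted.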
  {
    Sequential sequential(
        Identity(),
        ConvTranspose3d(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose3dImpl>(sequential[1])
        ->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2}));
    auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{{{ 128., 280., 154.},
            { 304., 664., 364.},
            { 184., 400., 218.}},
           {{ 352., 768., 420.},
            { 832., 1808., 984.},
            { 496., 1072., 580.}},
           {{ 256., 552., 298.},
            { 592., 1272., 684.},
            { 344., 736., 394.}}},
          {{{ 192., 424., 234.},
            { 464., 1016., 556.},
            { 280., 608., 330.}},
           {{ 544., 1184., 644.},
            {1280., 2768., 1496.},
            { 752., 1616., 868.}},
           {{ 384., 824., 442.},
            { 880., 1880., 1004.},
            { 504., 1072., 570.}}}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
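  // EmbeddingBag: forward()'s optional offsets and per_sample_weights
  // arguments are omitted.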
  {
    auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
    Sequential sequential(Identity(), EmbeddingBag::from_pretrained(weight));
    auto x = torch::tensor({{1, 0}}, torch::kLong);
    auto y = sequential->forward(x);
    auto expected = torch::tensor({2.5000, 3.7000, 4.6500});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
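  // MultiheadAttention: forward()'s optional key_padding_mask, need_weights,
  // and attn_mask arguments are omitted.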
  {
    torch::manual_seed(0);
    int64_t embed_dim = 8;
    int64_t num_heads = 4;
    int64_t batch_size = 8;
    int64_t src_len = 3;
    int64_t tgt_len = 1;
    auto query = torch::ones({batch_size, tgt_len, embed_dim});
    auto key = torch::ones({batch_size, src_len, embed_dim});
    auto value = key;
    Sequential sequential(MultiheadAttention(embed_dim, num_heads));
    auto output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(
        query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1));
    auto attn_output = std::get<0>(output);
    auto attn_output_expected = torch::tensor(
        {{{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}}});
    ASSERT_TRUE(torch::allclose(attn_output, attn_output_expected, 1e-05, 2e-04));
    auto attn_output_weights = std::get<1>(output);
    auto attn_output_weights_expected = torch::tensor(
        {{{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}}});
    ASSERT_TRUE(torch::allclose(attn_output_weights, attn_output_weights_expected, 1e-05, 2e-04));
  }
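  // MaxUnpool1d: forward()'s optional output_size argument is omitted.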
  {
    auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    auto x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat));
    Sequential sequential(MaxUnpool1d(3));
    auto y = sequential->forward(x, indices);
    auto expected = torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
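  // MaxUnpool2d: forward()'s optional output_size argument is omitted.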
  {
    auto indices = torch::tensor(
        {{{{ 6, 8, 9},
           {16, 18, 19},
           {21, 23, 24}}},
         {{{ 6, 8, 9},
           {16, 18, 19},
           {21, 23, 24}}}}, torch::kLong);
    auto x = torch::tensor(
        {{{{ 6, 8, 9},
           {16, 18, 19},
           {21, 23, 24}}},
         {{{31, 33, 34},
           {41, 43, 44},
           {46, 48, 49}}}}, torch::dtype(torch::kFloat));
    Sequential sequential(MaxUnpool2d(MaxUnpool2dOptions(3).stride(2).padding(1)));
    auto y = sequential->forward(x, indices);
    auto expected = torch::tensor(
        {{{{ 0, 0, 0, 0, 0},
           { 0, 6, 0, 8, 9},
           { 0, 0, 0, 0, 0},
           { 0, 16, 0, 18, 19},
           { 0, 21, 0, 23, 24}}},
         {{{ 0, 0, 0, 0, 0},
           { 0, 31, 0, 33, 34},
           { 0, 0, 0, 0, 0},
           { 0, 41, 0, 43, 44},
           { 0, 46, 0, 48, 49}}}}, torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
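  // MaxUnpool3d: forward()'s optional output_size argument is omitted.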
  {
    auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
    auto x = torch::tensor(
        {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
    Sequential sequential(MaxUnpool3d(3));
    auto y = sequential->forward(x, indices);
    auto expected = torch::tensor(
        {{{{{ 0, 0, 0},
            { 0, 0, 0},
            { 0, 0, 0}},
           {{ 0, 0, 0},
            { 0, 0, 0},
            { 0, 0, 0}},
           {{ 0, 0, 0},
            { 0, 0, 0},
            { 0, 0, 26}}}}}, torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
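  // RNN: forward()'s optional initial state argument is omitted.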
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), RNN(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output = sequential->forward<RNNOutput>(x);
    auto expected_output = torch::tensor(
        {{{0.0000, 0.0000, 0.4886},
          {0.0000, 0.0000, 0.4886},
          {0.0000, 0.0000, 0.4886}},
         {{0.0000, 0.0000, 0.3723},
          {0.0000, 0.0000, 0.3723},
          {0.0000, 0.0000, 0.3723}}});
    ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
  }
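  // LSTM: forward()'s optional initial state argument is omitted.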
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), LSTM(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output = sequential->forward<RNNOutput>(x);
    auto expected_output = torch::tensor(
        {{{-0.2693, -0.1240, 0.0744},
          {-0.2693, -0.1240, 0.0744},
          {-0.2693, -0.1240, 0.0744}},
         {{-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183}}});
    ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
  }
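  // GRU: forward()'s optional initial state argument is omitted.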
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), GRU(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output = sequential->forward<RNNOutput>(x);
    auto expected_output = torch::tensor(
        {{{-0.1134, 0.0467, 0.2336},
          {-0.1134, 0.0467, 0.2336},
          {-0.1134, 0.0467, 0.2336}},
         {{-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960}}});
    ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
  }
}