mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 03:04:55 +08:00
[C++ API] Allow skipping default arguments in module's forward method when module is used in Sequential (#33027)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33027 This PR allows default arguments in module's forward method to be skipped when module is used in `torch::nn::Sequential`, by introducing the `FORWARD_HAS_DEFAULT_ARGS` macro and requiring that all modules that have default arguments in its forward method must have a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro call. Fixes issue mentioned in https://github.com/pytorch/pytorch/issues/30931#issuecomment-564144468. Test Plan: Imported from OSS Differential Revision: D19777815 Pulled By: yf225 fbshipit-source-id: 73282fcf63377530063e0092a9d84b6c139d2e32
This commit is contained in:
committed by
Facebook Github Bot
parent
4724964810
commit
a203dc2e6d
@ -436,3 +436,208 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
|
||||
" (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
|
||||
")");
|
||||
}
|
||||
|
||||
TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
|
||||
{
|
||||
Sequential sequential(Identity(), ConvTranspose1d(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false)));
|
||||
std::dynamic_pointer_cast<ConvTranspose1dImpl>(sequential[1])->weight.set_data(torch::arange(18.).reshape({3, 2, 3}));
|
||||
auto x = torch::arange(30.).reshape({2, 3, 5});
|
||||
auto y = sequential->forward(x);
|
||||
auto expected = torch::tensor({{{ 150., 333., 552., 615., 678., 501., 276.},
|
||||
{ 195., 432., 714., 804., 894., 654., 357.}},
|
||||
{{ 420., 918., 1497., 1560., 1623., 1176., 636.},
|
||||
{ 600., 1287., 2064., 2154., 2244., 1599., 852.}}});
|
||||
ASSERT_TRUE(torch::allclose(y, expected));
|
||||
}
|
||||
{
|
||||
Sequential sequential(Identity(), ConvTranspose2d(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false)));
|
||||
std::dynamic_pointer_cast<ConvTranspose2dImpl>(sequential[1])->weight.set_data(torch::arange(54.).reshape({3, 2, 3, 3}));
|
||||
auto x = torch::arange(75.).reshape({1, 3, 5, 5});
|
||||
auto y = sequential->forward(x);
|
||||
auto expected = torch::tensor({{{{ 2250., 4629., 7140., 7311., 7482., 5133., 2640.},
|
||||
{ 4995., 10272., 15837., 16206., 16575., 11364., 5841.},
|
||||
{ 8280., 17019., 26226., 26820., 27414., 18783., 9648.},
|
||||
{ 9225., 18954., 29196., 29790., 30384., 20808., 10683.},
|
||||
{10170., 20889., 32166., 32760., 33354., 22833., 11718.},
|
||||
{ 7515., 15420., 23721., 24144., 24567., 16800., 8613.},
|
||||
{ 4140., 8487., 13044., 13269., 13494., 9219., 4722.}},
|
||||
{{ 2925., 6006., 9246., 9498., 9750., 6672., 3423.},
|
||||
{ 6480., 13296., 20454., 20985., 21516., 14712., 7542.},
|
||||
{10710., 21960., 33759., 34596., 35433., 24210., 12402.},
|
||||
{12060., 24705., 37944., 38781., 39618., 27045., 13842.},
|
||||
{13410., 27450., 42129., 42966., 43803., 29880., 15282.},
|
||||
{ 9810., 20064., 30768., 31353., 31938., 21768., 11124.},
|
||||
{ 5355., 10944., 16770., 17076., 17382., 11838., 6045.}}}});
|
||||
ASSERT_TRUE(torch::allclose(y, expected));
|
||||
}
|
||||
{
|
||||
Sequential sequential(Identity(), ConvTranspose3d(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false)));
|
||||
std::dynamic_pointer_cast<ConvTranspose3dImpl>(sequential[1])->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2}));
|
||||
auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2});
|
||||
auto y = sequential->forward(x);
|
||||
auto expected = torch::tensor({{{{{ 128., 280., 154.},
|
||||
{ 304., 664., 364.},
|
||||
{ 184., 400., 218.}},
|
||||
{{ 352., 768., 420.},
|
||||
{ 832., 1808., 984.},
|
||||
{ 496., 1072., 580.}},
|
||||
{{ 256., 552., 298.},
|
||||
{ 592., 1272., 684.},
|
||||
{ 344., 736., 394.}}},
|
||||
{{{ 192., 424., 234.},
|
||||
{ 464., 1016., 556.},
|
||||
{ 280., 608., 330.}},
|
||||
{{ 544., 1184., 644.},
|
||||
{1280., 2768., 1496.},
|
||||
{ 752., 1616., 868.}},
|
||||
{{ 384., 824., 442.},
|
||||
{ 880., 1880., 1004.},
|
||||
{ 504., 1072., 570.}}}}});
|
||||
ASSERT_TRUE(torch::allclose(y, expected));
|
||||
}
|
||||
{
|
||||
auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
|
||||
Sequential sequential(Identity(), EmbeddingBag::from_pretrained(weight));
|
||||
auto x = torch::tensor({{1, 0}}, torch::kLong);
|
||||
auto y = sequential->forward(x);
|
||||
auto expected = torch::tensor({2.5000, 3.7000, 4.6500});
|
||||
ASSERT_TRUE(torch::allclose(y, expected));
|
||||
}
|
||||
{
|
||||
torch::manual_seed(0);
|
||||
|
||||
int64_t embed_dim = 8;
|
||||
int64_t num_heads = 4;
|
||||
int64_t batch_size = 8;
|
||||
int64_t src_len = 3;
|
||||
int64_t tgt_len = 1;
|
||||
|
||||
auto query = torch::ones({batch_size, tgt_len, embed_dim});
|
||||
auto key = torch::ones({batch_size, src_len, embed_dim});
|
||||
auto value = key;
|
||||
|
||||
Sequential sequential(MultiheadAttention(embed_dim, num_heads));
|
||||
auto output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1));
|
||||
|
||||
auto attn_output = std::get<0>(output);
|
||||
auto attn_output_expected = torch::tensor(
|
||||
{{{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
|
||||
{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
|
||||
{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
|
||||
{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
|
||||
{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
|
||||
{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
|
||||
{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
|
||||
{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}}});
|
||||
ASSERT_TRUE(torch::allclose(attn_output, attn_output_expected, 1e-05, 2e-04));
|
||||
|
||||
auto attn_output_weights = std::get<1>(output);
|
||||
auto attn_output_weights_expected = torch::tensor(
|
||||
{{{0.3333, 0.3333, 0.3333}},
|
||||
{{0.3333, 0.3333, 0.3333}},
|
||||
{{0.3333, 0.3333, 0.3333}},
|
||||
{{0.3333, 0.3333, 0.3333}},
|
||||
{{0.3333, 0.3333, 0.3333}},
|
||||
{{0.3333, 0.3333, 0.3333}},
|
||||
{{0.3333, 0.3333, 0.3333}},
|
||||
{{0.3333, 0.3333, 0.3333}}});
|
||||
ASSERT_TRUE(torch::allclose(attn_output_weights, attn_output_weights_expected, 1e-05, 2e-04));
|
||||
}
|
||||
{
|
||||
auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
|
||||
auto x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat));
|
||||
Sequential sequential(MaxUnpool1d(3));
|
||||
auto y = sequential->forward(x, indices);
|
||||
auto expected = torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat);
|
||||
ASSERT_TRUE(torch::allclose(y, expected));
|
||||
}
|
||||
{
|
||||
auto indices = torch::tensor(
|
||||
{{{{ 6, 8, 9},
|
||||
{16, 18, 19},
|
||||
{21, 23, 24}}},
|
||||
{{{ 6, 8, 9},
|
||||
{16, 18, 19},
|
||||
{21, 23, 24}}}}, torch::kLong);
|
||||
auto x = torch::tensor(
|
||||
{{{{ 6, 8, 9},
|
||||
{16, 18, 19},
|
||||
{21, 23, 24}}},
|
||||
{{{31, 33, 34},
|
||||
{41, 43, 44},
|
||||
{46, 48, 49}}}}, torch::dtype(torch::kFloat));
|
||||
Sequential sequential(MaxUnpool2d(MaxUnpool2dOptions(3).stride(2).padding(1)));
|
||||
auto y = sequential->forward(x, indices);
|
||||
auto expected = torch::tensor(
|
||||
{{{{ 0, 0, 0, 0, 0},
|
||||
{ 0, 6, 0, 8, 9},
|
||||
{ 0, 0, 0, 0, 0},
|
||||
{ 0, 16, 0, 18, 19},
|
||||
{ 0, 21, 0, 23, 24}}},
|
||||
{{{ 0, 0, 0, 0, 0},
|
||||
{ 0, 31, 0, 33, 34},
|
||||
{ 0, 0, 0, 0, 0},
|
||||
{ 0, 41, 0, 43, 44},
|
||||
{ 0, 46, 0, 48, 49}}}} , torch::kFloat);
|
||||
ASSERT_TRUE(torch::allclose(y, expected));
|
||||
}
|
||||
{
|
||||
auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
|
||||
auto x = torch::tensor({{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
|
||||
Sequential sequential(MaxUnpool3d(3));
|
||||
auto y = sequential->forward(x, indices);
|
||||
auto expected = torch::tensor(
|
||||
{{{{{ 0, 0, 0},
|
||||
{ 0, 0, 0},
|
||||
{ 0, 0, 0}},
|
||||
{{ 0, 0, 0},
|
||||
{ 0, 0, 0},
|
||||
{ 0, 0, 0}},
|
||||
{{ 0, 0, 0},
|
||||
{ 0, 0, 0},
|
||||
{ 0, 0, 26}}}}}, torch::kFloat);
|
||||
ASSERT_TRUE(torch::allclose(y, expected));
|
||||
}
|
||||
{
|
||||
torch::manual_seed(0);
|
||||
Sequential sequential(Identity(), RNN(2, 3));
|
||||
auto x = torch::ones({2, 3, 2});
|
||||
auto rnn_output = sequential->forward<RNNOutput>(x);
|
||||
auto expected_output = torch::tensor(
|
||||
{{{0.0000, 0.0000, 0.4886},
|
||||
{0.0000, 0.0000, 0.4886},
|
||||
{0.0000, 0.0000, 0.4886}},
|
||||
{{0.0000, 0.0000, 0.3723},
|
||||
{0.0000, 0.0000, 0.3723},
|
||||
{0.0000, 0.0000, 0.3723}}});
|
||||
ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
|
||||
}
|
||||
{
|
||||
torch::manual_seed(0);
|
||||
Sequential sequential(Identity(), LSTM(2, 3));
|
||||
auto x = torch::ones({2, 3, 2});
|
||||
auto rnn_output = sequential->forward<RNNOutput>(x);
|
||||
auto expected_output = torch::tensor(
|
||||
{{{-0.2693, -0.1240, 0.0744},
|
||||
{-0.2693, -0.1240, 0.0744},
|
||||
{-0.2693, -0.1240, 0.0744}},
|
||||
{{-0.3889, -0.1919, 0.1183},
|
||||
{-0.3889, -0.1919, 0.1183},
|
||||
{-0.3889, -0.1919, 0.1183}}});
|
||||
ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
|
||||
}
|
||||
{
|
||||
torch::manual_seed(0);
|
||||
Sequential sequential(Identity(), GRU(2, 3));
|
||||
auto x = torch::ones({2, 3, 2});
|
||||
auto rnn_output = sequential->forward<RNNOutput>(x);
|
||||
auto expected_output = torch::tensor(
|
||||
{{{-0.1134, 0.0467, 0.2336},
|
||||
{-0.1134, 0.0467, 0.2336},
|
||||
{-0.1134, 0.0467, 0.2336}},
|
||||
{{-0.1189, 0.0502, 0.2960},
|
||||
{-0.1189, 0.0502, 0.2960},
|
||||
{-0.1189, 0.0502, 0.2960}}});
|
||||
ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user