C++ API parity: AdaptiveAvgPool2d

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26818

Test Plan: Imported from OSS

Differential Revision: D17627822

Pulled By: pbelevich

fbshipit-source-id: 0e1dea1c3ff2650dbc7902ce704ac6b47588d0bb
This commit is contained in:
Pavel Belevich
2019-09-28 10:43:21 -07:00
committed by Facebook Github Bot
parent 7d58060f49
commit a31fd5ea68
7 changed files with 93 additions and 0 deletions

View File

@ -481,6 +481,50 @@ TEST_F(ModulesTest, AdaptiveAvgPool1d) {
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 3}));
}
TEST_F(ModulesTest, AdaptiveAvgPool2dEven) {
AdaptiveAvgPool2d model(3);
auto x = torch::arange(0, 50);
x.resize_({2, 5, 5}).set_requires_grad(true);
auto y = model(x);
torch::Tensor s = y.sum();
s.backward();
ASSERT_EQ(s.ndimension(), 0);
ASSERT_EQ(y.ndimension(), 3);
ASSERT_TRUE(torch::allclose(y, torch::tensor({
{{ 3.0, 4.5, 6.0},
{10.5, 12.0, 13.5},
{18.0, 19.5, 21.0}},
{{28.0, 29.5, 31.0},
{35.5, 37.0, 38.5},
{43.0, 44.5, 46.0}},
}, torch::kFloat)));
ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 3, 3}));
}
TEST_F(ModulesTest, AdaptiveAvgPool2dUneven) {
AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2}));
auto x = torch::arange(0, 40);
x.resize_({2, 5, 4}).set_requires_grad(true);
auto y = model(x);
torch::Tensor s = y.sum();
s.backward();
ASSERT_EQ(s.ndimension(), 0);
ASSERT_EQ(y.ndimension(), 3);
ASSERT_TRUE(torch::allclose(y, torch::tensor({
{{2.5, 4.5},
{8.5, 10.5},
{14.5, 16.5}},
{{22.5, 24.5},
{28.5, 30.5},
{34.5, 36.5}},
}, torch::kFloat)));
ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 3, 2}));
}
TEST_F(ModulesTest, Linear) {
Linear model(5, 2);
auto x = torch::randn({10, 5}, torch::requires_grad());
@ -861,6 +905,13 @@ TEST_F(ModulesTest, PrettyPrintAdaptiveAvgPool) {
ASSERT_EQ(
c10::str(AdaptiveAvgPool1d(5)),
"torch::nn::AdaptiveAvgPool1d(output_size=5)");
ASSERT_EQ(
c10::str(AdaptiveAvgPool2d(5)),
"torch::nn::AdaptiveAvgPool2d(output_size=[5, 5])");
ASSERT_EQ(
c10::str(AdaptiveAvgPool2d(torch::IntArrayRef{5, 6})),
"torch::nn::AdaptiveAvgPool2d(output_size=[5, 6])");
}
TEST_F(ModulesTest, PrettyPrintDropout) {