mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
C++ API parity: AdaptiveAvgPool2d
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26818 Test Plan: Imported from OSS Differential Revision: D17627822 Pulled By: pbelevich fbshipit-source-id: 0e1dea1c3ff2650dbc7902ce704ac6b47588d0bb
This commit is contained in:
committed by
Facebook Github Bot
parent
7d58060f49
commit
a31fd5ea68
@ -481,6 +481,50 @@ TEST_F(ModulesTest, AdaptiveAvgPool1d) {
|
||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 3}));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, AdaptiveAvgPool2dEven) {
|
||||
AdaptiveAvgPool2d model(3);
|
||||
auto x = torch::arange(0, 50);
|
||||
x.resize_({2, 5, 5}).set_requires_grad(true);
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_TRUE(torch::allclose(y, torch::tensor({
|
||||
{{ 3.0, 4.5, 6.0},
|
||||
{10.5, 12.0, 13.5},
|
||||
{18.0, 19.5, 21.0}},
|
||||
{{28.0, 29.5, 31.0},
|
||||
{35.5, 37.0, 38.5},
|
||||
{43.0, 44.5, 46.0}},
|
||||
}, torch::kFloat)));
|
||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 3, 3}));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, AdaptiveAvgPool2dUneven) {
|
||||
AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2}));
|
||||
auto x = torch::arange(0, 40);
|
||||
x.resize_({2, 5, 4}).set_requires_grad(true);
|
||||
auto y = model(x);
|
||||
torch::Tensor s = y.sum();
|
||||
|
||||
s.backward();
|
||||
ASSERT_EQ(s.ndimension(), 0);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 3);
|
||||
ASSERT_TRUE(torch::allclose(y, torch::tensor({
|
||||
{{2.5, 4.5},
|
||||
{8.5, 10.5},
|
||||
{14.5, 16.5}},
|
||||
{{22.5, 24.5},
|
||||
{28.5, 30.5},
|
||||
{34.5, 36.5}},
|
||||
}, torch::kFloat)));
|
||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 3, 2}));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Linear) {
|
||||
Linear model(5, 2);
|
||||
auto x = torch::randn({10, 5}, torch::requires_grad());
|
||||
@ -861,6 +905,13 @@ TEST_F(ModulesTest, PrettyPrintAdaptiveAvgPool) {
|
||||
ASSERT_EQ(
|
||||
c10::str(AdaptiveAvgPool1d(5)),
|
||||
"torch::nn::AdaptiveAvgPool1d(output_size=5)");
|
||||
|
||||
ASSERT_EQ(
|
||||
c10::str(AdaptiveAvgPool2d(5)),
|
||||
"torch::nn::AdaptiveAvgPool2d(output_size=[5, 5])");
|
||||
ASSERT_EQ(
|
||||
c10::str(AdaptiveAvgPool2d(torch::IntArrayRef{5, 6})),
|
||||
"torch::nn::AdaptiveAvgPool2d(output_size=[5, 6])");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintDropout) {
|
||||
|
||||
Reference in New Issue
Block a user