mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
C++ API parity: MaxUnpool3d
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27027 Test Plan: Imported from OSS Differential Revision: D17682402 Pulled By: pbelevich fbshipit-source-id: 2008ce405176c174cdba88b4f25cd77a82bb13ea
This commit is contained in:
committed by
Facebook Github Bot
parent
f4f6d8dda5
commit
5005f7bce7
@ -651,6 +651,71 @@ TEST_F(ModulesTest, MaxPool2d_MaxUnpool2d) {
|
||||
{ 0, 0, 0, 0, 0}}}}, torch::kFloat)));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, MaxUnpool3d) {
|
||||
auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
|
||||
auto x = torch::tensor({{{{{26}}}}}, torch::requires_grad());
|
||||
auto model = MaxUnpool3d{3};
|
||||
auto y = model->forward(x, indices);
|
||||
|
||||
ASSERT_EQ(y.dim(), 5);
|
||||
ASSERT_TRUE(torch::allclose(y, torch::tensor(
|
||||
{{{{{ 0, 0, 0},
|
||||
{ 0, 0, 0},
|
||||
{ 0, 0, 0}},
|
||||
{{ 0, 0, 0},
|
||||
{ 0, 0, 0},
|
||||
{ 0, 0, 0}},
|
||||
{{ 0, 0, 0},
|
||||
{ 0, 0, 0},
|
||||
{ 0, 0, 26}}}}}, torch::kFloat)));
|
||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 3, 3, 3}));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, MaxUnpool3dOutputSize) {
|
||||
auto indices = torch::tensor(
|
||||
{{{{{21, 23},
|
||||
{29, 31}},
|
||||
{{53, 55},
|
||||
{61, 63}}}}}, torch::kLong);
|
||||
auto x = torch::tensor(
|
||||
{{{{{21, 23},
|
||||
{29, 31}},
|
||||
{{53, 55},
|
||||
{61, 63}}}}}, torch::requires_grad());
|
||||
auto model = MaxUnpool3d{MaxUnpool3dOptions(3).stride(2).padding(1)};
|
||||
auto y = model->forward(x, indices, torch::IntArrayRef({1, 1, 4, 4, 4}));
|
||||
|
||||
ASSERT_EQ(y.dim(), 5);
|
||||
ASSERT_TRUE(torch::allclose(y, torch::tensor(
|
||||
{{{{{ 0, 0, 0, 0},
|
||||
{ 0, 0, 0, 0},
|
||||
{ 0, 0, 0, 0},
|
||||
{ 0, 0, 0, 0}},
|
||||
{{ 0, 0, 0, 0},
|
||||
{ 0, 21, 0, 23},
|
||||
{ 0, 0, 0, 0},
|
||||
{ 0, 29, 0, 31}},
|
||||
{{ 0, 0, 0, 0},
|
||||
{ 0, 0, 0, 0},
|
||||
{ 0, 0, 0, 0},
|
||||
{ 0, 0, 0, 0}},
|
||||
{{ 0, 0, 0, 0},
|
||||
{ 0, 53, 0, 55},
|
||||
{ 0, 0, 0, 0},
|
||||
{ 0, 61, 0, 63}}}}}, torch::kFloat)));
|
||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 4, 4, 4}));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, MaxPool3d_MaxUnpool3d) {
|
||||
MaxPool3d pool {MaxPool3dOptions(3).stride(2)};
|
||||
MaxUnpool3d unpool {MaxUnpool3dOptions(3).stride(2)};
|
||||
auto input = torch::randn({20, 16, 51, 33, 15});
|
||||
torch::Tensor output, indices;
|
||||
std::tie(output, indices) = pool->forward_with_indices(input);
|
||||
auto unpooled_output = unpool(output, indices);
|
||||
ASSERT_EQ(unpooled_output.sizes(), torch::IntArrayRef({20, 16, 51, 33, 15}));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Linear) {
|
||||
Linear model(5, 2);
|
||||
auto x = torch::randn({10, 5}, torch::requires_grad());
|
||||
|
||||
Reference in New Issue
Block a user