mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Adding nn.ZeroPad1d and nn.ZeroPad3d (#96295)
Fixes #95796 ### Implementation Adds python implementation for `nn.ZeroPad1d` and `nn.ZeroPad3d` in `torch/nn/modules/padding.py`. Adds cpp implementation for `nn::ZeroPad1d` and `nn::ZeroPad3d` in the following 3 files, refactored with templates similarly to `nn::ConstantPad`'s implementation: <br> - `torch/crsc/api/include/torch/nn/modules/padding.h` - `torch/csrc/api/include/torch/nn/options/padding.h` - `torch/csrc/api/src/nn/modules/padding.cpp` Also added relevant definitions in `torch/nn/modules/__init__.py`. ### Testing Adds the following tests: - cpp tests of similar length and structure as `ConstantPad` and the existing `ZeroPad2d` impl in `test/cpp/api/modules.cpp` - cpp API parity tests in `torch/testing/_internal/common_nn.py` - module init tests in `test/test_module_init.py` Also added relevant definitions in `test/cpp_api_parity/parity-tracker.md` Pull Request resolved: https://github.com/pytorch/pytorch/pull/96295 Approved by: https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
d0e4ca233e
commit
f3b8638074
@ -4061,6 +4061,27 @@ TEST_F(ModulesTest, ReplicationPad3d) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, ZeroPad1d) {
|
||||
{
|
||||
ZeroPad1d m(ZeroPad1dOptions(2));
|
||||
auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
|
||||
auto output = m(input);
|
||||
auto expected = torch::tensor(
|
||||
{{{0., 0., 0., 1., 2., 3., 0., 0.}, {0., 0., 4., 5., 6., 7., 0., 0.}}},
|
||||
torch::kFloat);
|
||||
ASSERT_TRUE(output.allclose(expected));
|
||||
}
|
||||
{
|
||||
ZeroPad1d m(ZeroPad1dOptions({3, 1}));
|
||||
auto input = torch::arange(6, torch::kFloat).reshape({1, 2, 3});
|
||||
auto output = m(input);
|
||||
auto expected = torch::tensor(
|
||||
{{{0., 0., 0., 0., 1., 2., 0.}, {0., 0., 0., 3., 4., 5., 0.}}},
|
||||
torch::kFloat);
|
||||
ASSERT_TRUE(output.allclose(expected));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, ZeroPad2d) {
|
||||
{
|
||||
ZeroPad2d m(ZeroPad2dOptions(2));
|
||||
@ -4092,6 +4113,66 @@ TEST_F(ModulesTest, ZeroPad2d) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, ZeroPad3d) {
|
||||
{
|
||||
ZeroPad3d m(ZeroPad3dOptions(1));
|
||||
auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
|
||||
auto output = m(input);
|
||||
auto expected = torch::tensor(
|
||||
{{{{{0., 0., 0., 0.},
|
||||
{0., 0., 0., 0.},
|
||||
{0., 0., 0., 0.},
|
||||
{0., 0., 0., 0.}},
|
||||
{{0., 0., 0., 0.},
|
||||
{0., 0., 1., 0.},
|
||||
{0., 2., 3., 0.},
|
||||
{0., 0., 0., 0.}},
|
||||
{{0., 0., 0., 0.},
|
||||
{0., 4., 5., 0.},
|
||||
{0., 6., 7., 0.},
|
||||
{0., 0., 0., 0.}},
|
||||
{{0., 0., 0., 0.},
|
||||
{0., 0., 0., 0.},
|
||||
{0., 0., 0., 0.},
|
||||
{0., 0., 0., 0.}}}}},
|
||||
torch::kFloat);
|
||||
ASSERT_TRUE(output.allclose(expected));
|
||||
}
|
||||
{
|
||||
ZeroPad3d m(ZeroPad3dOptions({1, 2, 1, 2, 1, 2}));
|
||||
auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
|
||||
auto output = m(input);
|
||||
auto expected = torch::tensor(
|
||||
{{{{{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.}},
|
||||
{{0., 0., 0., 0., 0.},
|
||||
{0., 0., 1., 0., 0.},
|
||||
{0., 2., 3., 0., 0.},
|
||||
{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.}},
|
||||
{{0., 0., 0., 0., 0.},
|
||||
{0., 4., 5., 0., 0.},
|
||||
{0., 6., 7., 0., 0.},
|
||||
{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.}},
|
||||
{{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.}},
|
||||
{{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.},
|
||||
{0., 0., 0., 0., 0.}}}}},
|
||||
torch::kFloat);
|
||||
ASSERT_TRUE(output.allclose(expected));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, ConstantPad1d) {
|
||||
{
|
||||
ConstantPad1d m(ConstantPad1dOptions(2, 3.5));
|
||||
@ -4935,13 +5016,25 @@ TEST_F(ModulesTest, PrettyPrintReplicationPad) {
|
||||
"torch::nn::ReplicationPad3d(padding=[1, 2, 1, 2, 1, 2])");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintZeroPad2d) {
|
||||
TEST_F(ModulesTest, PrettyPrintZeroPad) {
|
||||
ASSERT_EQ(
|
||||
c10::str(ZeroPad1d(ZeroPad1dOptions(2))),
|
||||
"torch::nn::ZeroPad1d(padding=[2, 2])");
|
||||
ASSERT_EQ(
|
||||
c10::str(ZeroPad1d(ZeroPad1dOptions({3, 1}))),
|
||||
"torch::nn::ZeroPad1d(padding=[3, 1])");
|
||||
ASSERT_EQ(
|
||||
c10::str(ZeroPad2d(ZeroPad2dOptions(2))),
|
||||
"torch::nn::ZeroPad2d(padding=[2, 2, 2, 2])");
|
||||
ASSERT_EQ(
|
||||
c10::str(ZeroPad2d(ZeroPad2dOptions({1, 1, 2, 0}))),
|
||||
"torch::nn::ZeroPad2d(padding=[1, 1, 2, 0])");
|
||||
ASSERT_EQ(
|
||||
c10::str(ZeroPad3d(ZeroPad3dOptions(1))),
|
||||
"torch::nn::ZeroPad3d(padding=[1, 1, 1, 1, 1, 1])");
|
||||
ASSERT_EQ(
|
||||
c10::str(ZeroPad3d(ZeroPad3dOptions({1, 2, 1, 2, 1, 2}))),
|
||||
"torch::nn::ZeroPad3d(padding=[1, 2, 1, 2, 1, 2])");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintConstantPad) {
|
||||
|
Reference in New Issue
Block a user