Adding nn.ZeroPad1d and nn.ZeroPad3d (#96295)

Fixes #95796

### Implementation
Adds python implementation for `nn.ZeroPad1d` and `nn.ZeroPad3d` in `torch/nn/modules/padding.py`.

Adds cpp implementation for `nn::ZeroPad1d` and `nn::ZeroPad3d` in the following 3 files, refactored with templates similarly to `nn::ConstantPad`'s implementation: <br>
- `torch/crsc/api/include/torch/nn/modules/padding.h`
- `torch/csrc/api/include/torch/nn/options/padding.h`
- `torch/csrc/api/src/nn/modules/padding.cpp`

Also added relevant definitions in `torch/nn/modules/__init__.py`.
### Testing
Adds the following tests:
-  cpp tests of similar length and structure as `ConstantPad` and the existing `ZeroPad2d` impl in `test/cpp/api/modules.cpp`
- cpp API parity tests in `torch/testing/_internal/common_nn.py`
- module init tests in `test/test_module_init.py`

Also added relevant definitions in `test/cpp_api_parity/parity-tracker.md`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96295
Approved by: https://github.com/soulitzer
This commit is contained in:
Rishub Tamirisa
2023-03-10 03:51:37 +00:00
committed by PyTorch MergeBot
parent d0e4ca233e
commit f3b8638074
10 changed files with 345 additions and 42 deletions

View File

@ -4061,6 +4061,27 @@ TEST_F(ModulesTest, ReplicationPad3d) {
}
}
TEST_F(ModulesTest, ZeroPad1d) {
{
ZeroPad1d m(ZeroPad1dOptions(2));
auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
auto output = m(input);
auto expected = torch::tensor(
{{{0., 0., 0., 1., 2., 3., 0., 0.}, {0., 0., 4., 5., 6., 7., 0., 0.}}},
torch::kFloat);
ASSERT_TRUE(output.allclose(expected));
}
{
ZeroPad1d m(ZeroPad1dOptions({3, 1}));
auto input = torch::arange(6, torch::kFloat).reshape({1, 2, 3});
auto output = m(input);
auto expected = torch::tensor(
{{{0., 0., 0., 0., 1., 2., 0.}, {0., 0., 0., 3., 4., 5., 0.}}},
torch::kFloat);
ASSERT_TRUE(output.allclose(expected));
}
}
TEST_F(ModulesTest, ZeroPad2d) {
{
ZeroPad2d m(ZeroPad2dOptions(2));
@ -4092,6 +4113,66 @@ TEST_F(ModulesTest, ZeroPad2d) {
}
}
TEST_F(ModulesTest, ZeroPad3d) {
{
ZeroPad3d m(ZeroPad3dOptions(1));
auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
auto output = m(input);
auto expected = torch::tensor(
{{{{{0., 0., 0., 0.},
{0., 0., 0., 0.},
{0., 0., 0., 0.},
{0., 0., 0., 0.}},
{{0., 0., 0., 0.},
{0., 0., 1., 0.},
{0., 2., 3., 0.},
{0., 0., 0., 0.}},
{{0., 0., 0., 0.},
{0., 4., 5., 0.},
{0., 6., 7., 0.},
{0., 0., 0., 0.}},
{{0., 0., 0., 0.},
{0., 0., 0., 0.},
{0., 0., 0., 0.},
{0., 0., 0., 0.}}}}},
torch::kFloat);
ASSERT_TRUE(output.allclose(expected));
}
{
ZeroPad3d m(ZeroPad3dOptions({1, 2, 1, 2, 1, 2}));
auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
auto output = m(input);
auto expected = torch::tensor(
{{{{{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.}},
{{0., 0., 0., 0., 0.},
{0., 0., 1., 0., 0.},
{0., 2., 3., 0., 0.},
{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.}},
{{0., 0., 0., 0., 0.},
{0., 4., 5., 0., 0.},
{0., 6., 7., 0., 0.},
{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.}},
{{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.}},
{{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.},
{0., 0., 0., 0., 0.}}}}},
torch::kFloat);
ASSERT_TRUE(output.allclose(expected));
}
}
TEST_F(ModulesTest, ConstantPad1d) {
{
ConstantPad1d m(ConstantPad1dOptions(2, 3.5));
@ -4935,13 +5016,25 @@ TEST_F(ModulesTest, PrettyPrintReplicationPad) {
"torch::nn::ReplicationPad3d(padding=[1, 2, 1, 2, 1, 2])");
}
TEST_F(ModulesTest, PrettyPrintZeroPad2d) {
TEST_F(ModulesTest, PrettyPrintZeroPad) {
ASSERT_EQ(
c10::str(ZeroPad1d(ZeroPad1dOptions(2))),
"torch::nn::ZeroPad1d(padding=[2, 2])");
ASSERT_EQ(
c10::str(ZeroPad1d(ZeroPad1dOptions({3, 1}))),
"torch::nn::ZeroPad1d(padding=[3, 1])");
ASSERT_EQ(
c10::str(ZeroPad2d(ZeroPad2dOptions(2))),
"torch::nn::ZeroPad2d(padding=[2, 2, 2, 2])");
ASSERT_EQ(
c10::str(ZeroPad2d(ZeroPad2dOptions({1, 1, 2, 0}))),
"torch::nn::ZeroPad2d(padding=[1, 1, 2, 0])");
ASSERT_EQ(
c10::str(ZeroPad3d(ZeroPad3dOptions(1))),
"torch::nn::ZeroPad3d(padding=[1, 1, 1, 1, 1, 1])");
ASSERT_EQ(
c10::str(ZeroPad3d(ZeroPad3dOptions({1, 2, 1, 2, 1, 2}))),
"torch::nn::ZeroPad3d(padding=[1, 2, 1, 2, 1, 2])");
}
TEST_F(ModulesTest, PrettyPrintConstantPad) {