mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Add padding='same' mode to conv{1,2,3}d (#45667)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45667 First part of #3867 (Pooling operators still to do) This adds a `padding='same'` mode to the interface of `conv{n}d`and `nn.Conv{n}d`. This should match the behaviour of `tensorflow`. I couldn't find it explicitly documented but through experimentation I found `tensorflow` returns the shape `ceil(len/stride)` and always adds any extra asymmetric padding onto the right side of the input. Since the `native_functions.yaml` schema doesn't seem to support strings or enums, I've moved the function interface into python and it now dispatches between the numerically padded `conv{n}d` and the `_conv{n}d_same` variant. Underscores because I couldn't see any way to avoid exporting a function into the `torch` namespace. A note on asymmetric padding. The total padding required can be odd if both the kernel-length is even and the dilation is odd. mkldnn has native support for asymmetric padding, so there is no overhead there, but for other backends I resort to padding the input tensor by 1 on the right hand side to make the remaining padding symmetrical. In these cases, I use `TORCH_WARN_ONCE` to notify the user of the performance implications. Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D27170744 Pulled By: jbschlosser fbshipit-source-id: b3d8a0380e0787ae781f2e5d8ee365a7bfd49f22
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a8a1090324
commit
04e0cbf5a9
@ -55,6 +55,15 @@ TEST_F(ModulesTest, Conv1d) {
|
||||
ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Conv1dSameStrided) {
|
||||
auto options = Conv1dOptions(3, 2, 3);
|
||||
options.stride(1).padding(torch::kSame);
|
||||
Conv1d model_valid(options);
|
||||
ASSERT_THROWS_WITH(
|
||||
[&]{ Conv1d model_invalid(options.stride(2)); }(),
|
||||
"padding='same' is not supported for strided convolutions");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Conv2dEven) {
|
||||
Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false));
|
||||
model->weight.set_data(torch::arange(54, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3}));
|
||||
@ -95,6 +104,18 @@ TEST_F(ModulesTest, Conv2dUneven) {
|
||||
ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Conv2dSameStrided) {
|
||||
auto options = Conv2dOptions(3, 2, {3, 4});
|
||||
options.stride(1).padding(torch::kSame);
|
||||
Conv2d model_valid(options);
|
||||
ASSERT_THROWS_WITH(
|
||||
[&]{ Conv2d model_invalid(options.stride(2)); }(),
|
||||
"padding='same' is not supported for strided convolutions");
|
||||
ASSERT_THROWS_WITH(
|
||||
[&]{ Conv2d model_invalid(options.stride({1, 2})); }(),
|
||||
"padding='same' is not supported for strided convolutions");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Conv3d) {
|
||||
Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false));
|
||||
model->weight.set_data(torch::arange(162, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3, 3}));
|
||||
@ -131,6 +152,18 @@ TEST_F(ModulesTest, Conv3d) {
|
||||
ASSERT_TRUE(model->weight.grad().numel() == 3 * 2 * 3 * 3 * 3);
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Conv3dSameStrided) {
|
||||
auto options = Conv3dOptions(3, 2, {3, 4, 5});
|
||||
options.stride(1).padding(torch::kSame);
|
||||
Conv3d model_valid(options);
|
||||
ASSERT_THROWS_WITH(
|
||||
[&]{ Conv3d model_invalid(options.stride(2)); }(),
|
||||
"padding='same' is not supported for strided convolutions");
|
||||
ASSERT_THROWS_WITH(
|
||||
[&]{ Conv3d model_invalid(options.stride({1, 2, 1})); }(),
|
||||
"padding='same' is not supported for strided convolutions");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, ConvTranspose1d) {
|
||||
ConvTranspose1d model(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false));
|
||||
model->weight.set_data(torch::arange(18.).view({2, 3, 3}));
|
||||
|
Reference in New Issue
Block a user