Add C++ nn::functional pad

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26601

Test Plan: Imported from OSS

Differential Revision: D17517468

Pulled By: yf225

fbshipit-source-id: 9ee8b93b88a60f91f2ae78c242f9eaa246b3293c
This commit is contained in:
Will Feng
2019-10-21 22:16:21 -07:00
committed by Facebook Github Bot
parent 94757e035d
commit 079b3cc02c
11 changed files with 291 additions and 2 deletions

View File

@ -576,6 +576,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/api/src/nn/options/dropout.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/linear.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/normalization.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/padding.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/pooling.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/rnn.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/adagrad.cpp

View File

@ -25,7 +25,11 @@ TEST(EnumTest, AllEnums) {
torch::enumtype::kReLU,
torch::enumtype::kLeakyReLU,
torch::enumtype::kFanIn,
torch::enumtype::kFanOut
torch::enumtype::kFanOut,
torch::enumtype::kConstant,
torch::enumtype::kReflect,
torch::enumtype::kReplicate,
torch::enumtype::kCircular
> v;
TORCH_ENUM_PRETTY_PRINT_TEST(Linear)
@ -41,4 +45,8 @@ TEST(EnumTest, AllEnums) {
TORCH_ENUM_PRETTY_PRINT_TEST(LeakyReLU)
TORCH_ENUM_PRETTY_PRINT_TEST(FanIn)
TORCH_ENUM_PRETTY_PRINT_TEST(FanOut)
TORCH_ENUM_PRETTY_PRINT_TEST(Constant)
TORCH_ENUM_PRETTY_PRINT_TEST(Reflect)
TORCH_ENUM_PRETTY_PRINT_TEST(Replicate)
TORCH_ENUM_PRETTY_PRINT_TEST(Circular)
}

View File

@ -1126,3 +1126,144 @@ TEST_F(FunctionalTest, Threshold) {
}
}
}
TEST_F(FunctionalTest, Pad) {
{
auto input = torch::arange(6, torch::kDouble).reshape({1, 2, 3});
auto output = F::pad(input, PadOptions({1, 2}).mode(torch::kCircular));
auto expected = torch::tensor({{{2., 0., 1., 2., 0., 1.},
{5., 3., 4., 5., 3., 4.}}}, torch::kDouble);
ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 2, 6}));
ASSERT_TRUE(output.allclose(expected, 1e-04));
}
{
auto input = torch::arange(9, torch::kDouble).reshape({1, 1, 3, 3});
auto output = F::pad(input, PadOptions({3, 3, 3, 1}).mode(torch::kCircular));
auto expected = torch::tensor(
{{{{0., 1., 2., 0., 1., 2., 0., 1., 2.},
{3., 4., 5., 3., 4., 5., 3., 4., 5.},
{6., 7., 8., 6., 7., 8., 6., 7., 8.},
{0., 1., 2., 0., 1., 2., 0., 1., 2.},
{3., 4., 5., 3., 4., 5., 3., 4., 5.},
{6., 7., 8., 6., 7., 8., 6., 7., 8.},
{0., 1., 2., 0., 1., 2., 0., 1., 2.}}}}, torch::kDouble);
ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 9}));
ASSERT_TRUE(output.allclose(expected, 1e-04));
}
{
auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
auto output = F::pad(input, PadOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular));
auto expected = torch::tensor(
{{{{{ 0., 1., 2., 0., 1., 2., 0., 1., 2.},
{ 3., 4., 5., 3., 4., 5., 3., 4., 5.},
{ 0., 1., 2., 0., 1., 2., 0., 1., 2.},
{ 3., 4., 5., 3., 4., 5., 3., 4., 5.},
{ 0., 1., 2., 0., 1., 2., 0., 1., 2.}},
{{ 6., 7., 8., 6., 7., 8., 6., 7., 8.},
{ 9., 10., 11., 9., 10., 11., 9., 10., 11.},
{ 6., 7., 8., 6., 7., 8., 6., 7., 8.},
{ 9., 10., 11., 9., 10., 11., 9., 10., 11.},
{ 6., 7., 8., 6., 7., 8., 6., 7., 8.}},
{{ 0., 1., 2., 0., 1., 2., 0., 1., 2.},
{ 3., 4., 5., 3., 4., 5., 3., 4., 5.},
{ 0., 1., 2., 0., 1., 2., 0., 1., 2.},
{ 3., 4., 5., 3., 4., 5., 3., 4., 5.},
{ 0., 1., 2., 0., 1., 2., 0., 1., 2.}},
{{ 6., 7., 8., 6., 7., 8., 6., 7., 8.},
{ 9., 10., 11., 9., 10., 11., 9., 10., 11.},
{ 6., 7., 8., 6., 7., 8., 6., 7., 8.},
{ 9., 10., 11., 9., 10., 11., 9., 10., 11.},
{ 6., 7., 8., 6., 7., 8., 6., 7., 8.}},
{{ 0., 1., 2., 0., 1., 2., 0., 1., 2.},
{ 3., 4., 5., 3., 4., 5., 3., 4., 5.},
{ 0., 1., 2., 0., 1., 2., 0., 1., 2.},
{ 3., 4., 5., 3., 4., 5., 3., 4., 5.},
{ 0., 1., 2., 0., 1., 2., 0., 1., 2.}},
{{ 6., 7., 8., 6., 7., 8., 6., 7., 8.},
{ 9., 10., 11., 9., 10., 11., 9., 10., 11.},
{ 6., 7., 8., 6., 7., 8., 6., 7., 8.},
{ 9., 10., 11., 9., 10., 11., 9., 10., 11.},
{ 6., 7., 8., 6., 7., 8., 6., 7., 8.}}}}}, torch::kDouble);
ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 6, 5, 9}));
ASSERT_TRUE(output.allclose(expected, 1e-04));
}
{
auto input = torch::arange(16, torch::kDouble).reshape({2, 2, 2, 2});
auto output = F::pad(input, PadOptions({1, 1, 1, 1}).mode(torch::kReflect));
auto expected = torch::tensor(
{{{{ 3., 2., 3., 2.},
{ 1., 0., 1., 0.},
{ 3., 2., 3., 2.},
{ 1., 0., 1., 0.}},
{{ 7., 6., 7., 6.},
{ 5., 4., 5., 4.},
{ 7., 6., 7., 6.},
{ 5., 4., 5., 4.}}},
{{{11., 10., 11., 10.},
{ 9., 8., 9., 8.},
{11., 10., 11., 10.},
{ 9., 8., 9., 8.}},
{{15., 14., 15., 14.},
{13., 12., 13., 12.},
{15., 14., 15., 14.},
{13., 12., 13., 12.}}}}, torch::kDouble);
ASSERT_EQ(output.sizes(), std::vector<int64_t>({2, 2, 4, 4}));
ASSERT_TRUE(output.allclose(expected, 1e-04));
}
{
auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
auto output = F::pad(input, PadOptions({1, 2, 2, 1, 1, 2}).mode(torch::kReplicate));
auto expected = torch::tensor(
{{{{{ 0., 0., 1., 2., 2., 2.},
{ 0., 0., 1., 2., 2., 2.},
{ 0., 0., 1., 2., 2., 2.},
{ 3., 3., 4., 5., 5., 5.},
{ 3., 3., 4., 5., 5., 5.}},
{{ 0., 0., 1., 2., 2., 2.},
{ 0., 0., 1., 2., 2., 2.},
{ 0., 0., 1., 2., 2., 2.},
{ 3., 3., 4., 5., 5., 5.},
{ 3., 3., 4., 5., 5., 5.}},
{{ 6., 6., 7., 8., 8., 8.},
{ 6., 6., 7., 8., 8., 8.},
{ 6., 6., 7., 8., 8., 8.},
{ 9., 9., 10., 11., 11., 11.},
{ 9., 9., 10., 11., 11., 11.}},
{{ 6., 6., 7., 8., 8., 8.},
{ 6., 6., 7., 8., 8., 8.},
{ 6., 6., 7., 8., 8., 8.},
{ 9., 9., 10., 11., 11., 11.},
{ 9., 9., 10., 11., 11., 11.}},
{{ 6., 6., 7., 8., 8., 8.},
{ 6., 6., 7., 8., 8., 8.},
{ 6., 6., 7., 8., 8., 8.},
{ 9., 9., 10., 11., 11., 11.},
{ 9., 9., 10., 11., 11., 11.}}}}}, torch::kDouble);
ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 5, 5, 6}));
ASSERT_TRUE(output.allclose(expected, 1e-04));
}
{
auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
auto output = F::pad(input, PadOptions({1, 1}).mode(torch::kConstant).value(0));
ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
}
{
auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
auto output = F::pad(input, PadOptions({1, 1}));
ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
}
}

View File

@ -232,6 +232,7 @@ def add_torch_libs():
"torch/csrc/api/src/nn/options/dropout.cpp",
"torch/csrc/api/src/nn/options/linear.cpp",
"torch/csrc/api/src/nn/options/normalization.cpp",
"torch/csrc/api/src/nn/options/padding.cpp",
"torch/csrc/api/src/nn/options/pooling.cpp",
"torch/csrc/api/src/nn/options/rnn.cpp",
"torch/csrc/api/src/optim/adagrad.cpp",

View File

@ -26,7 +26,7 @@ const enumtype::k##name k##name; \
}
#define TORCH_ENUM_PRETTY_PRINT(name) \
const char* operator()(enumtype::k##name& v) const { \
const char* operator()(const enumtype::k##name& v) const { \
return #name; \
}
@ -43,6 +43,10 @@ TORCH_ENUM_DECLARE(ReLU)
TORCH_ENUM_DECLARE(LeakyReLU)
TORCH_ENUM_DECLARE(FanIn)
TORCH_ENUM_DECLARE(FanOut)
TORCH_ENUM_DECLARE(Constant)
TORCH_ENUM_DECLARE(Reflect)
TORCH_ENUM_DECLARE(Replicate)
TORCH_ENUM_DECLARE(Circular)
namespace torch {
namespace enumtype {
@ -60,6 +64,10 @@ struct enum_name {
TORCH_ENUM_PRETTY_PRINT(LeakyReLU)
TORCH_ENUM_PRETTY_PRINT(FanIn)
TORCH_ENUM_PRETTY_PRINT(FanOut)
TORCH_ENUM_PRETTY_PRINT(Constant)
TORCH_ENUM_PRETTY_PRINT(Reflect)
TORCH_ENUM_PRETTY_PRINT(Replicate)
TORCH_ENUM_PRETTY_PRINT(Circular)
};
} // namespace enumtype
} // namespace torch

View File

@ -5,6 +5,7 @@
#include <torch/nn/functional/fold.h>
#include <torch/nn/functional/loss.h>
#include <torch/nn/functional/normalization.h>
#include <torch/nn/functional/padding.h>
#include <torch/nn/functional/pixelshuffle.h>
#include <torch/nn/functional/pooling.h>
#include <torch/nn/functional/vision.h>

View File

@ -0,0 +1,82 @@
#pragma once
#include <torch/nn/options/padding.h>
namespace torch {
namespace nn {
namespace functional {
inline Tensor _narrow_with_range(const Tensor& input, int64_t dim, int64_t start, int64_t end) {
return input.narrow(dim, start, end - start);
}
inline Tensor _pad_circular(Tensor input, IntArrayRef padding) {
input = torch::cat({input, _narrow_with_range(input, 2, 0, padding[-1 + padding.size()])}, /*dim=*/2);
input = torch::cat({_narrow_with_range(input, 2, -(padding[-1 + padding.size()] + padding[-2 + padding.size()]), -padding[-1 + padding.size()]), input}, /*dim=*/2);
if (padding.size() > 2) {
input = torch::cat({input, _narrow_with_range(input, 3, 0, padding[-3 + padding.size()])}, /*dim=*/3);
input = torch::cat({_narrow_with_range(input, 3, -(padding[-3 + padding.size()] + padding[-4 + padding.size()]), -padding[-3 + padding.size()]), input}, /*dim=*/3);
}
if (padding.size() > 4) {
input = torch::cat({input, _narrow_with_range(input, 4, 0, padding[-5 + padding.size()])}, /*dim=*/4);
input = torch::cat({_narrow_with_range(input, 4, -(padding[-5 + padding.size()] + padding[-6 + padding.size()]), -padding[-5 + padding.size()]), input}, /*dim=*/4);
}
return input;
}
inline Tensor pad(const Tensor& input, const PadOptions& options) {
TORCH_CHECK(options.pad().size() % 2 == 0, "Padding length must be divisible by 2");
TORCH_CHECK(((int64_t)(options.pad().size() / 2)) <= input.dim(), "Padding length too large");
if (c10::get_if<enumtype::kConstant>(&options.mode())) {
return torch::constant_pad_nd(input, options.pad(), options.value());
} else {
TORCH_CHECK(
options.value() == 0,
"Padding mode \"",
c10::visit(torch::enumtype::enum_name{}, options.mode()),
"\" doesn't take in value argument");
if (input.dim() == 3) {
TORCH_CHECK(options.pad().size() == 2, "3D tensors expect 2 values for padding");
if (c10::get_if<enumtype::kReflect>(&options.mode())) {
return torch::reflection_pad1d(input, options.pad());
} else if (c10::get_if<enumtype::kReplicate>(&options.mode())) {
return torch::replication_pad1d(input, options.pad());
} else if (c10::get_if<enumtype::kCircular>(&options.mode())) {
return _pad_circular(input, options.pad());
} else {
TORCH_CHECK(false, "NotImplementedError");
}
} else if (input.dim() == 4) {
TORCH_CHECK(options.pad().size() == 4, "4D tensors expect 4 values for padding");
if (c10::get_if<enumtype::kReflect>(&options.mode())) {
return torch::reflection_pad2d(input, options.pad());
} else if (c10::get_if<enumtype::kReplicate>(&options.mode())) {
return torch::replication_pad2d(input, options.pad());
} else if (c10::get_if<enumtype::kCircular>(&options.mode())) {
return _pad_circular(input, options.pad());
} else {
TORCH_CHECK(false, "NotImplementedError");
}
} else if (input.dim() == 5) {
TORCH_CHECK(options.pad().size() == 6, "5D tensors expect 6 values for padding");
if (c10::get_if<enumtype::kReflect>(&options.mode())) {
TORCH_CHECK(false, "NotImplementedError");
} else if (c10::get_if<enumtype::kReplicate>(&options.mode())) {
return torch::replication_pad3d(input, options.pad());
} else if (c10::get_if<enumtype::kCircular>(&options.mode())) {
return _pad_circular(input, options.pad());
} else {
TORCH_CHECK(false, "NotImplementedError");
}
} else {
TORCH_CHECK(false, "Only 3D, 4D, 5D padding with non-constant padding are supported for now");
}
}
}
} // namespace functional
} // namespace nn
} // namespace torch

View File

@ -7,6 +7,7 @@
#include <torch/nn/options/linear.h>
#include <torch/nn/options/loss.h>
#include <torch/nn/options/normalization.h>
#include <torch/nn/options/padding.h>
#include <torch/nn/options/pooling.h>
#include <torch/nn/options/rnn.h>
#include <torch/nn/options/pixelshuffle.h>

View File

@ -0,0 +1,33 @@
#pragma once
#include <c10/util/variant.h>
#include <torch/arg.h>
#include <torch/enum.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/expanding_array.h>
#include <torch/types.h>
namespace torch {
namespace nn {
/// Options for a pad functional.
struct TORCH_API PadOptions {
PadOptions(std::vector<int64_t> pad);
/// m-elements tuple, where m/2 <= input dimensions and m is even.
TORCH_ARG(std::vector<int64_t>, pad);
/// "constant", "reflect", "replicate" or "circular". Default: "constant"
typedef c10::variant<
enumtype::kConstant,
enumtype::kReflect,
enumtype::kReplicate,
enumtype::kCircular> mode_t;
TORCH_ARG(mode_t, mode) = torch::kConstant;
/// fill value for "constant" padding. Default: 0
TORCH_ARG(double, value) = 0;
};
} // namespace nn
} // namespace torch

View File

@ -13,3 +13,7 @@ TORCH_ENUM_DEFINE(ReLU)
TORCH_ENUM_DEFINE(LeakyReLU)
TORCH_ENUM_DEFINE(FanIn)
TORCH_ENUM_DEFINE(FanOut)
TORCH_ENUM_DEFINE(Constant)
TORCH_ENUM_DEFINE(Reflect)
TORCH_ENUM_DEFINE(Replicate)
TORCH_ENUM_DEFINE(Circular)

View File

@ -0,0 +1,9 @@
#include <torch/nn/options/padding.h>
namespace torch {
namespace nn {
PadOptions::PadOptions(std::vector<int64_t> pad) : pad_(std::move(pad)) {}
} // namespace nn
} // namespace torch