Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add C++ nn::functional pad
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26601
Test Plan: Imported from OSS
Differential Revision: D17517468
Pulled By: yf225
fbshipit-source-id: 9ee8b93b88a60f91f2ae78c242f9eaa246b3293c
Committed by: Facebook Github Bot
Parent: 94757e035d
Commit: 079b3cc02c
@@ -576,6 +576,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/dropout.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/linear.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/normalization.cpp
+      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/padding.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/pooling.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/rnn.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/optim/adagrad.cpp
@@ -25,7 +25,11 @@ TEST(EnumTest, AllEnums) {
     torch::enumtype::kReLU,
     torch::enumtype::kLeakyReLU,
     torch::enumtype::kFanIn,
-    torch::enumtype::kFanOut
+    torch::enumtype::kFanOut,
+    torch::enumtype::kConstant,
+    torch::enumtype::kReflect,
+    torch::enumtype::kReplicate,
+    torch::enumtype::kCircular
   > v;

   TORCH_ENUM_PRETTY_PRINT_TEST(Linear)
@@ -41,4 +45,8 @@ TEST(EnumTest, AllEnums) {
   TORCH_ENUM_PRETTY_PRINT_TEST(LeakyReLU)
   TORCH_ENUM_PRETTY_PRINT_TEST(FanIn)
   TORCH_ENUM_PRETTY_PRINT_TEST(FanOut)
+  TORCH_ENUM_PRETTY_PRINT_TEST(Constant)
+  TORCH_ENUM_PRETTY_PRINT_TEST(Reflect)
+  TORCH_ENUM_PRETTY_PRINT_TEST(Replicate)
+  TORCH_ENUM_PRETTY_PRINT_TEST(Circular)
 }
@@ -1126,3 +1126,144 @@ TEST_F(FunctionalTest, Threshold) {
     }
   }
 }
+
+TEST_F(FunctionalTest, Pad) {
+  {
+    auto input = torch::arange(6, torch::kDouble).reshape({1, 2, 3});
+    auto output = F::pad(input, PadOptions({1, 2}).mode(torch::kCircular));
+    auto expected = torch::tensor({{{2., 0., 1., 2., 0., 1.},
+                                    {5., 3., 4., 5., 3., 4.}}}, torch::kDouble);
+    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 2, 6}));
+    ASSERT_TRUE(output.allclose(expected, 1e-04));
+  }
+  {
+    auto input = torch::arange(9, torch::kDouble).reshape({1, 1, 3, 3});
+    auto output = F::pad(input, PadOptions({3, 3, 3, 1}).mode(torch::kCircular));
+    auto expected = torch::tensor(
+      {{{{0., 1., 2., 0., 1., 2., 0., 1., 2.},
+         {3., 4., 5., 3., 4., 5., 3., 4., 5.},
+         {6., 7., 8., 6., 7., 8., 6., 7., 8.},
+         {0., 1., 2., 0., 1., 2., 0., 1., 2.},
+         {3., 4., 5., 3., 4., 5., 3., 4., 5.},
+         {6., 7., 8., 6., 7., 8., 6., 7., 8.},
+         {0., 1., 2., 0., 1., 2., 0., 1., 2.}}}}, torch::kDouble);
+    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 9}));
+    ASSERT_TRUE(output.allclose(expected, 1e-04));
+  }
+  {
+    auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
+    auto output = F::pad(input, PadOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular));
+    auto expected = torch::tensor(
+      {{{{{ 0., 1., 2., 0., 1., 2., 0., 1., 2.},
+          { 3., 4., 5., 3., 4., 5., 3., 4., 5.},
+          { 0., 1., 2., 0., 1., 2., 0., 1., 2.},
+          { 3., 4., 5., 3., 4., 5., 3., 4., 5.},
+          { 0., 1., 2., 0., 1., 2., 0., 1., 2.}},
+
+         {{ 6., 7., 8., 6., 7., 8., 6., 7., 8.},
+          { 9., 10., 11., 9., 10., 11., 9., 10., 11.},
+          { 6., 7., 8., 6., 7., 8., 6., 7., 8.},
+          { 9., 10., 11., 9., 10., 11., 9., 10., 11.},
+          { 6., 7., 8., 6., 7., 8., 6., 7., 8.}},
+
+         {{ 0., 1., 2., 0., 1., 2., 0., 1., 2.},
+          { 3., 4., 5., 3., 4., 5., 3., 4., 5.},
+          { 0., 1., 2., 0., 1., 2., 0., 1., 2.},
+          { 3., 4., 5., 3., 4., 5., 3., 4., 5.},
+          { 0., 1., 2., 0., 1., 2., 0., 1., 2.}},
+
+         {{ 6., 7., 8., 6., 7., 8., 6., 7., 8.},
+          { 9., 10., 11., 9., 10., 11., 9., 10., 11.},
+          { 6., 7., 8., 6., 7., 8., 6., 7., 8.},
+          { 9., 10., 11., 9., 10., 11., 9., 10., 11.},
+          { 6., 7., 8., 6., 7., 8., 6., 7., 8.}},
+
+         {{ 0., 1., 2., 0., 1., 2., 0., 1., 2.},
+          { 3., 4., 5., 3., 4., 5., 3., 4., 5.},
+          { 0., 1., 2., 0., 1., 2., 0., 1., 2.},
+          { 3., 4., 5., 3., 4., 5., 3., 4., 5.},
+          { 0., 1., 2., 0., 1., 2., 0., 1., 2.}},
+
+         {{ 6., 7., 8., 6., 7., 8., 6., 7., 8.},
+          { 9., 10., 11., 9., 10., 11., 9., 10., 11.},
+          { 6., 7., 8., 6., 7., 8., 6., 7., 8.},
+          { 9., 10., 11., 9., 10., 11., 9., 10., 11.},
+          { 6., 7., 8., 6., 7., 8., 6., 7., 8.}}}}}, torch::kDouble);
+    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 6, 5, 9}));
+    ASSERT_TRUE(output.allclose(expected, 1e-04));
+  }
+  {
+    auto input = torch::arange(16, torch::kDouble).reshape({2, 2, 2, 2});
+    auto output = F::pad(input, PadOptions({1, 1, 1, 1}).mode(torch::kReflect));
+    auto expected = torch::tensor(
+      {{{{ 3., 2., 3., 2.},
+         { 1., 0., 1., 0.},
+         { 3., 2., 3., 2.},
+         { 1., 0., 1., 0.}},
+
+        {{ 7., 6., 7., 6.},
+         { 5., 4., 5., 4.},
+         { 7., 6., 7., 6.},
+         { 5., 4., 5., 4.}}},
+
+       {{{11., 10., 11., 10.},
+         { 9., 8., 9., 8.},
+         {11., 10., 11., 10.},
+         { 9., 8., 9., 8.}},
+
+        {{15., 14., 15., 14.},
+         {13., 12., 13., 12.},
+         {15., 14., 15., 14.},
+         {13., 12., 13., 12.}}}}, torch::kDouble);
+    ASSERT_EQ(output.sizes(), std::vector<int64_t>({2, 2, 4, 4}));
+    ASSERT_TRUE(output.allclose(expected, 1e-04));
+  }
+  {
+    auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
+    auto output = F::pad(input, PadOptions({1, 2, 2, 1, 1, 2}).mode(torch::kReplicate));
+    auto expected = torch::tensor(
+      {{{{{ 0., 0., 1., 2., 2., 2.},
+          { 0., 0., 1., 2., 2., 2.},
+          { 0., 0., 1., 2., 2., 2.},
+          { 3., 3., 4., 5., 5., 5.},
+          { 3., 3., 4., 5., 5., 5.}},
+
+         {{ 0., 0., 1., 2., 2., 2.},
+          { 0., 0., 1., 2., 2., 2.},
+          { 0., 0., 1., 2., 2., 2.},
+          { 3., 3., 4., 5., 5., 5.},
+          { 3., 3., 4., 5., 5., 5.}},
+
+         {{ 6., 6., 7., 8., 8., 8.},
+          { 6., 6., 7., 8., 8., 8.},
+          { 6., 6., 7., 8., 8., 8.},
+          { 9., 9., 10., 11., 11., 11.},
+          { 9., 9., 10., 11., 11., 11.}},
+
+         {{ 6., 6., 7., 8., 8., 8.},
+          { 6., 6., 7., 8., 8., 8.},
+          { 6., 6., 7., 8., 8., 8.},
+          { 9., 9., 10., 11., 11., 11.},
+          { 9., 9., 10., 11., 11., 11.}},
+
+         {{ 6., 6., 7., 8., 8., 8.},
+          { 6., 6., 7., 8., 8., 8.},
+          { 6., 6., 7., 8., 8., 8.},
+          { 9., 9., 10., 11., 11., 11.},
+          { 9., 9., 10., 11., 11., 11.}}}}}, torch::kDouble);
+    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 5, 5, 6}));
+    ASSERT_TRUE(output.allclose(expected, 1e-04));
+  }
+  {
+    auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
+    auto output = F::pad(input, PadOptions({1, 1}).mode(torch::kConstant).value(0));
+    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
+    auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
+    ASSERT_TRUE(output.allclose(expected, 1e-04));
+  }
+  {
+    auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
+    auto output = F::pad(input, PadOptions({1, 1}));
+    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
+    auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
+    ASSERT_TRUE(output.allclose(expected, 1e-04));
+  }
+}
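Circular padding wraps values around instead of filling with a constant: in the first case above, padding each length-3 row {0, 1, 2} by 1 on the left and 2 on the right prepends the last element and appends the first two, giving {2, 0, 1, 2, 0, 1}. A minimal standalone check, assuming the F alias these tests use for torch::nn::functional:

#include <torch/torch.h>
#include <iostream>

namespace F = torch::nn::functional;

int main() {
  auto input = torch::arange(6, torch::kDouble).reshape({1, 2, 3});
  auto out = F::pad(input, torch::nn::PadOptions({1, 2}).mode(torch::kCircular));
  std::cout << out << std::endl;  // each row wraps: 2 0 1 2 0 1 and 5 3 4 5 3 4
  return 0;
}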
@@ -232,6 +232,7 @@ def add_torch_libs():
         "torch/csrc/api/src/nn/options/dropout.cpp",
         "torch/csrc/api/src/nn/options/linear.cpp",
         "torch/csrc/api/src/nn/options/normalization.cpp",
+        "torch/csrc/api/src/nn/options/padding.cpp",
         "torch/csrc/api/src/nn/options/pooling.cpp",
         "torch/csrc/api/src/nn/options/rnn.cpp",
         "torch/csrc/api/src/optim/adagrad.cpp",
@@ -26,7 +26,7 @@ const enumtype::k##name k##name; \
 }

 #define TORCH_ENUM_PRETTY_PRINT(name) \
-  const char* operator()(enumtype::k##name& v) const { \
+  const char* operator()(const enumtype::k##name& v) const { \
     return #name; \
   }

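The added const qualifier is what lets this visitor work in pad() below, where c10::visit(enum_name{}, options.mode()) reaches the mode through a const accessor; c10::visit then passes each alternative as a const reference, which cannot bind to a non-const enumtype::k##name&. A hypothetical visitor illustrating the required shape (not part of this patch):

struct name_of_reflect {
  // Must accept const&: c10::visit forwards const alternatives
  // when the visited variant is const.
  const char* operator()(const torch::enumtype::kReflect&) const { return "Reflect"; }
};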
@@ -43,6 +43,10 @@ TORCH_ENUM_DECLARE(ReLU)
 TORCH_ENUM_DECLARE(LeakyReLU)
 TORCH_ENUM_DECLARE(FanIn)
 TORCH_ENUM_DECLARE(FanOut)
+TORCH_ENUM_DECLARE(Constant)
+TORCH_ENUM_DECLARE(Reflect)
+TORCH_ENUM_DECLARE(Replicate)
+TORCH_ENUM_DECLARE(Circular)

 namespace torch {
 namespace enumtype {
@@ -60,6 +64,10 @@ struct enum_name {
   TORCH_ENUM_PRETTY_PRINT(LeakyReLU)
   TORCH_ENUM_PRETTY_PRINT(FanIn)
   TORCH_ENUM_PRETTY_PRINT(FanOut)
+  TORCH_ENUM_PRETTY_PRINT(Constant)
+  TORCH_ENUM_PRETTY_PRINT(Reflect)
+  TORCH_ENUM_PRETTY_PRINT(Replicate)
+  TORCH_ENUM_PRETTY_PRINT(Circular)
 };
 } // namespace enumtype
 } // namespace torch
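Taken together, TORCH_ENUM_DECLARE introduces an empty tag type enumtype::kName plus a torch::kName value, and enum_name maps any such tag back to its string via c10::visit. A minimal sketch of how the new padding modes compose; the mode_t alias mirrors the one PadOptions defines below:

#include <c10/util/variant.h>
#include <torch/enum.h>

using mode_t = c10::variant<
    torch::enumtype::kConstant,
    torch::enumtype::kReflect,
    torch::enumtype::kReplicate,
    torch::enumtype::kCircular>;

const char* mode_name(const mode_t& mode) {
  // Each TORCH_ENUM_PRETTY_PRINT overload returns the literal enum name,
  // so mode_name(torch::kCircular) yields "Circular".
  return c10::visit(torch::enumtype::enum_name{}, mode);
}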
@@ -5,6 +5,7 @@
 #include <torch/nn/functional/fold.h>
 #include <torch/nn/functional/loss.h>
 #include <torch/nn/functional/normalization.h>
+#include <torch/nn/functional/padding.h>
 #include <torch/nn/functional/pixelshuffle.h>
 #include <torch/nn/functional/pooling.h>
 #include <torch/nn/functional/vision.h>
torch/csrc/api/include/torch/nn/functional/padding.h (new file, 82 lines)
@@ -0,0 +1,82 @@
+#pragma once
+
+#include <torch/nn/options/padding.h>
+
+namespace torch {
+namespace nn {
+namespace functional {
+
+inline Tensor _narrow_with_range(const Tensor& input, int64_t dim, int64_t start, int64_t end) {
+  return input.narrow(dim, start, end - start);
+}
+
+inline Tensor _pad_circular(Tensor input, IntArrayRef padding) {
+  input = torch::cat({input, _narrow_with_range(input, 2, 0, padding[-1 + padding.size()])}, /*dim=*/2);
+  input = torch::cat({_narrow_with_range(input, 2, -(padding[-1 + padding.size()] + padding[-2 + padding.size()]), -padding[-1 + padding.size()]), input}, /*dim=*/2);
+
+  if (padding.size() > 2) {
+    input = torch::cat({input, _narrow_with_range(input, 3, 0, padding[-3 + padding.size()])}, /*dim=*/3);
+    input = torch::cat({_narrow_with_range(input, 3, -(padding[-3 + padding.size()] + padding[-4 + padding.size()]), -padding[-3 + padding.size()]), input}, /*dim=*/3);
+  }
+
+  if (padding.size() > 4) {
+    input = torch::cat({input, _narrow_with_range(input, 4, 0, padding[-5 + padding.size()])}, /*dim=*/4);
+    input = torch::cat({_narrow_with_range(input, 4, -(padding[-5 + padding.size()] + padding[-6 + padding.size()]), -padding[-5 + padding.size()]), input}, /*dim=*/4);
+  }
+
+  return input;
+}
+
+inline Tensor pad(const Tensor& input, const PadOptions& options) {
+  TORCH_CHECK(options.pad().size() % 2 == 0, "Padding length must be divisible by 2");
+  TORCH_CHECK(((int64_t)(options.pad().size() / 2)) <= input.dim(), "Padding length too large");
+  if (c10::get_if<enumtype::kConstant>(&options.mode())) {
+    return torch::constant_pad_nd(input, options.pad(), options.value());
+  } else {
+    TORCH_CHECK(
+      options.value() == 0,
+      "Padding mode \"",
+      c10::visit(torch::enumtype::enum_name{}, options.mode()),
+      "\" doesn't take in value argument");
+    if (input.dim() == 3) {
+      TORCH_CHECK(options.pad().size() == 2, "3D tensors expect 2 values for padding");
+      if (c10::get_if<enumtype::kReflect>(&options.mode())) {
+        return torch::reflection_pad1d(input, options.pad());
+      } else if (c10::get_if<enumtype::kReplicate>(&options.mode())) {
+        return torch::replication_pad1d(input, options.pad());
+      } else if (c10::get_if<enumtype::kCircular>(&options.mode())) {
+        return _pad_circular(input, options.pad());
+      } else {
+        TORCH_CHECK(false, "NotImplementedError");
+      }
+    } else if (input.dim() == 4) {
+      TORCH_CHECK(options.pad().size() == 4, "4D tensors expect 4 values for padding");
+      if (c10::get_if<enumtype::kReflect>(&options.mode())) {
+        return torch::reflection_pad2d(input, options.pad());
+      } else if (c10::get_if<enumtype::kReplicate>(&options.mode())) {
+        return torch::replication_pad2d(input, options.pad());
+      } else if (c10::get_if<enumtype::kCircular>(&options.mode())) {
+        return _pad_circular(input, options.pad());
+      } else {
+        TORCH_CHECK(false, "NotImplementedError");
+      }
+    } else if (input.dim() == 5) {
+      TORCH_CHECK(options.pad().size() == 6, "5D tensors expect 6 values for padding");
+      if (c10::get_if<enumtype::kReflect>(&options.mode())) {
+        TORCH_CHECK(false, "NotImplementedError");
+      } else if (c10::get_if<enumtype::kReplicate>(&options.mode())) {
+        return torch::replication_pad3d(input, options.pad());
+      } else if (c10::get_if<enumtype::kCircular>(&options.mode())) {
+        return _pad_circular(input, options.pad());
+      } else {
+        TORCH_CHECK(false, "NotImplementedError");
+      }
+    } else {
+      TORCH_CHECK(false, "Only 3D, 4D, 5D padding with non-constant padding are supported for now");
+    }
+  }
+}
+
+} // namespace functional
+} // namespace nn
+} // namespace torch
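_pad_circular indexes the pad vector from its end (padding[-1 + padding.size()] is the last entry) because the pairs are listed from the last dimension backwards, as in Python's F.pad, while dim 2 is the first padded dimension of a batched input. For a 4-D input with pad {l, r, t, b}, the trailing pair {t, b} therefore pads dim 2 (height) and {l, r} pads dim 3 (width). A short sanity check against the second test case above, inside any function body and assuming the same F alias:

// pad = {left, right, top, bottom}: the trailing pair pads dim 2 (height),
// the leading pair pads dim 3 (width).
auto x = torch::arange(9, torch::kDouble).reshape({1, 1, 3, 3});
auto y = F::pad(x, torch::nn::PadOptions({3, 3, 3, 1}).mode(torch::kCircular));
// y.sizes() == {1, 1, 3 + 3 + 1, 3 + 3 + 3} == {1, 1, 7, 9}, matching the test.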
@@ -7,6 +7,7 @@
 #include <torch/nn/options/linear.h>
 #include <torch/nn/options/loss.h>
 #include <torch/nn/options/normalization.h>
+#include <torch/nn/options/padding.h>
 #include <torch/nn/options/pooling.h>
 #include <torch/nn/options/rnn.h>
 #include <torch/nn/options/pixelshuffle.h>
torch/csrc/api/include/torch/nn/options/padding.h (new file, 33 lines)
@@ -0,0 +1,33 @@
+#pragma once
+
+#include <c10/util/variant.h>
+#include <torch/arg.h>
+#include <torch/enum.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/expanding_array.h>
+#include <torch/types.h>
+
+namespace torch {
+namespace nn {
+
+/// Options for a pad functional.
+struct TORCH_API PadOptions {
+  PadOptions(std::vector<int64_t> pad);
+
+  /// m-elements tuple, where m/2 <= input dimensions and m is even.
+  TORCH_ARG(std::vector<int64_t>, pad);
+
+  /// "constant", "reflect", "replicate" or "circular". Default: "constant"
+  typedef c10::variant<
+    enumtype::kConstant,
+    enumtype::kReflect,
+    enumtype::kReplicate,
+    enumtype::kCircular> mode_t;
+  TORCH_ARG(mode_t, mode) = torch::kConstant;
+
+  /// fill value for "constant" padding. Default: 0
+  TORCH_ARG(double, value) = 0;
+};
+
+} // namespace nn
+} // namespace torch
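TORCH_ARG expands to a stored field (pad_, mode_, value_) plus a chainable setter and a getter, which is what enables the fluent PadOptions({...}).mode(...).value(...) construction used throughout the tests. A small sketch, inside a function body and with the same F alias:

// Each setter returns the options object, so calls chain.
auto opts = torch::nn::PadOptions({1, 1}).mode(torch::kConstant).value(0.5);
auto out = F::pad(torch::ones({1, 1, 1, 2}, torch::kDouble), opts);
// out is {{{{0.5, 1., 1., 0.5}}}} with sizes {1, 1, 1, 4}.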
@@ -13,3 +13,7 @@ TORCH_ENUM_DEFINE(ReLU)
 TORCH_ENUM_DEFINE(LeakyReLU)
 TORCH_ENUM_DEFINE(FanIn)
 TORCH_ENUM_DEFINE(FanOut)
+TORCH_ENUM_DEFINE(Constant)
+TORCH_ENUM_DEFINE(Reflect)
+TORCH_ENUM_DEFINE(Replicate)
+TORCH_ENUM_DEFINE(Circular)
torch/csrc/api/src/nn/options/padding.cpp (new file, 9 lines)
@@ -0,0 +1,9 @@
+#include <torch/nn/options/padding.h>
+
+namespace torch {
+namespace nn {
+
+PadOptions::PadOptions(std::vector<int64_t> pad) : pad_(std::move(pad)) {}
+
+} // namespace nn
+} // namespace torch