C++ Fold nn module

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24160

Differential Revision: D17260740

Pulled By: yf225

fbshipit-source-id: f0c7769316bed330289ca3d948f2e39c72ec928b
This commit is contained in:
Shahriar
2019-09-10 13:10:16 -07:00
committed by Facebook Github Bot
parent 2ab0f221ba
commit 3680cef44e
9 changed files with 121 additions and 3 deletions

View File

@ -529,6 +529,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/conv.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/dropout.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/embedding.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/fold.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/functional.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/linear.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/named_any.cpp

View File

@ -2194,6 +2194,7 @@ new_module_tests = [
dict(
fullname='Fold',
constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
cpp_constructor_args='(torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1}))',
input_size=(2, 16, 4),
check_gradgrad=False,
test_cuda=True,
@ -2208,6 +2209,7 @@ new_module_tests = [
dict(
fullname='Fold_int_input',
constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
cpp_constructor_args='(torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1))',
input_size=(2, 16, 4),
check_gradgrad=False,
test_cuda=True,

View File

@ -5,6 +5,7 @@
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/embedding.h>
#include <torch/nn/modules/fold.h>
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/types.h>
@ -118,6 +119,21 @@ TEST_F(ModulesTest, Linear) {
ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
}
// Fold of 3 channels of 2x2 blocks (12 block positions) into a 4x5 spatial
// map: output must be (1, 3, 4, 5) and gradients must flow to the input.
TEST_F(ModulesTest, Fold) {
  Fold fold(FoldOptions({4, 5}, {2, 2}));
  auto input = torch::randn({1, 3 * 2 * 2, 12});
  auto output = fold(input);
  torch::Tensor loss = output.sum();
  loss.backward();

  ASSERT_EQ(output.ndimension(), 4);
  ASSERT_EQ(loss.ndimension(), 0);
  const int64_t expected_sizes[] = {1, 3, 4, 5};
  for (int64_t dim = 0; dim < 4; ++dim) {
    ASSERT_EQ(output.size(dim), expected_sizes[dim]);
  }
}
TEST_F(ModulesTest, SimpleContainer) {
auto model = std::make_shared<SimpleContainer>();
auto l1 = model->add(Linear(10, 3), "l1");
@ -342,7 +358,8 @@ TEST_F(ModulesTest, PrettyPrintConv) {
c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))),
"torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[2, 2])");
const auto options = Conv2dOptions(3, 4, torch::IntArrayRef{5, 6}).stride({1, 2});
const auto options =
Conv2dOptions(3, 4, torch::IntArrayRef{5, 6}).stride({1, 2});
ASSERT_EQ(
c10::str(Conv2d(options)),
"torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 6], stride=[1, 2])");

View File

@ -16,7 +16,7 @@ torch.nn.ConvTranspose1d|No|No
torch.nn.ConvTranspose2d|No|No
torch.nn.ConvTranspose3d|No|No
torch.nn.Unfold|No|No
torch.nn.Fold|No|No
torch.nn.Fold|Yes|No
torch.nn.MaxPool1d|No|No
torch.nn.MaxPool2d|No|No
torch.nn.MaxPool3d|No|No

View File

@ -30,7 +30,10 @@ module_metadata_map = {
'ConvTranspose2d': TorchNNModuleMetadata(),
'ConvTranspose3d': TorchNNModuleMetadata(),
'Unfold': TorchNNModuleMetadata(),
'Fold': TorchNNModuleMetadata(),
'Fold': TorchNNModuleMetadata(
cpp_default_constructor_args="(3, 2)",
num_attrs_recursive=5,
),
'MaxPool1d': TorchNNModuleMetadata(),
'MaxPool2d': TorchNNModuleMetadata(),
'MaxPool3d': TorchNNModuleMetadata(),

View File

@ -189,6 +189,7 @@ def add_torch_libs():
"torch/csrc/api/src/nn/modules/conv.cpp",
"torch/csrc/api/src/nn/modules/dropout.cpp",
"torch/csrc/api/src/nn/modules/embedding.cpp",
"torch/csrc/api/src/nn/modules/fold.cpp",
"torch/csrc/api/src/nn/modules/functional.cpp",
"torch/csrc/api/src/nn/modules/linear.cpp",
"torch/csrc/api/src/nn/modules/named_any.cpp",

View File

@ -5,6 +5,7 @@
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/embedding.h>
#include <torch/nn/modules/fold.h>
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/modulelist.h>

View File

@ -0,0 +1,66 @@
#pragma once
#include <torch/expanding_array.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
namespace torch {
namespace nn {
/// Options for the `Fold` module, which combines an array of sliding local
/// blocks into a large containing tensor (implemented via `col2im`; see
/// `FoldImpl::forward`).
struct FoldOptions {
/// `output_size` and `kernel_size` are mandatory; `dilation`, `padding`
/// and `stride` default to a dense, un-padded, unit-stride fold.
/// NOTE(review): each TORCH_ARG below presumably declares the trailing
/// underscore member plus a chainable setter — see torch/arg.h to confirm.
FoldOptions(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size)
: output_size_(std::move(output_size)),
kernel_size_(std::move(kernel_size)) {}
/// describes the spatial shape of the large containing tensor of the sliding
/// local blocks. It is useful to resolve the ambiguity when multiple input
/// shapes map to same number of sliding blocks, e.g., with stride > 0.
TORCH_ARG(ExpandingArray<2>, output_size);
/// the size of the sliding blocks
TORCH_ARG(ExpandingArray<2>, kernel_size);
/// controls the spacing between the kernel points; also known as the à trous
/// algorithm.
TORCH_ARG(ExpandingArray<2>, dilation) = 1;
/// controls the amount of implicit zero-paddings on both sides for padding
/// number of points for each dimension before reshaping.
TORCH_ARG(ExpandingArray<2>, padding) = 0;
/// controls the stride for the sliding blocks.
TORCH_ARG(ExpandingArray<2>, stride) = 1;
};
/// Applies fold over a 3-D input.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.Fold to learn about
/// the exact behavior of this module.
class TORCH_API FoldImpl : public torch::nn::Cloneable<FoldImpl> {
public:
/// Convenience constructor; forwards to the `FoldOptions` overload.
FoldImpl(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size)
: FoldImpl(FoldOptions(output_size, kernel_size)) {}
explicit FoldImpl(FoldOptions options) : options(std::move(options)) {}
/// `Fold` holds no parameters or buffers, so there is nothing to reset.
void reset() override {}
/// Pretty prints the `Fold` module into the given `stream`.
void pretty_print(std::ostream& stream) const override {
stream << "torch::nn::Fold";
}
/// Reassembles `input` — a 3-D tensor of sliding local blocks — into the
/// containing tensor described by `options`. Only 3-D inputs are accepted
/// (checked in the definition).
Tensor forward(const Tensor& input);
/// The options with which this `Module` was constructed.
FoldOptions options;
};
/// A `ModuleHolder` subclass for `FoldImpl`.
/// See the documentation for `FoldImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
/// Generates the user-facing `Fold` type wrapping `FoldImpl`.
TORCH_MODULE(Fold);
} // namespace nn
} // namespace torch

View File

@ -0,0 +1,27 @@
#include <torch/nn/modules/fold.h>
#include <torch/expanding_array.h>
#include <torch/types.h>
#include <torch/utils.h>
namespace torch {
namespace nn {
// Combines the sliding local blocks in `input` into the large containing
// tensor described by `options`, delegating all work to the `col2im` kernel.
Tensor FoldImpl::forward(const Tensor& input) {
  const auto ndim = input.dim();
  TORCH_CHECK(
      ndim == 3,
      "Input Error: Only 3D input Tensors are supported (got ",
      ndim,
      "D)");
  const auto& opts = options;
  return torch::col2im(
      input,
      opts.output_size_,
      opts.kernel_size_,
      opts.dilation_,
      opts.padding_,
      opts.stride_);
}
} // namespace nn
} // namespace torch