Port dilated_max_pool2d() to ATen

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20691

Differential Revision: D15435960

Pulled By: ezyang

fbshipit-source-id: 548b7cc42e52ad2c641ec7d9cf78028d9411d02e
Author: Stefan Krah
Date: 2019-05-23 08:59:32 -07:00
Committed by: Facebook Github Bot
Parent: f039401bf2
Commit: ec57d1f18a
16 changed files with 1044 additions and 830 deletions
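
For orientation, a minimal usage sketch of the ported operator through the public ATen API declared below in native_functions.yaml (shapes and pooling parameters are illustrative, not taken from this diff):

#include <ATen/ATen.h>
#include <tuple>

int main() {
  // NCHW batch-mode input; kernel 3x3, stride 2, padding 1, dilation 2.
  at::Tensor input = at::randn({1, 3, 8, 8});
  at::Tensor output, indices;
  std::tie(output, indices) = at::max_pool2d_with_indices(
      input, {3, 3}, {2, 2}, {1, 1}, {2, 2}, /*ceil_mode=*/false);
  // indices holds, per output element, the flat offset of the max within
  // its input plane; the backward kernels route gradients through it.
  return 0;
}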


@ -0,0 +1,112 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/NativeFunctions.h>
#include <limits>
#include <tuple>
namespace at {
namespace native {
namespace {
template <typename dest_t, typename src_t>
static inline dest_t
safe_downcast(src_t v)
{
TORCH_CHECK(std::numeric_limits<dest_t>::min() <= v && v <= std::numeric_limits<dest_t>::max(),
"integer out of range");
return static_cast<dest_t>(v);
}
template<typename T>
static inline T pooling_output_shape(
T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
T outputSize = ((inputSize + 2 * pad - dilation * (kernelSize - 1) - 1 + (ceil_mode ? stride - 1 : 0)) / stride + 1);
if (pad) {
// ensure that the last pooling starts inside the image
// needed to avoid problems in ceil mode
if ((outputSize - 1) * stride >= inputSize + pad)
--outputSize;
}
return outputSize;
}
static inline void
max_pool2d_with_indices_shape_check(
const Tensor& input,
int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
int64_t nInputPlane,
int64_t inputHeight, int64_t inputWidth,
int64_t outputHeight, int64_t outputWidth)
{
const int64_t ndim = input.ndimension();
const int64_t nOutputPlane = nInputPlane;
TORCH_CHECK(kW > 0 && kH > 0,
"kernel size should be greater than zero, but got ",
"kH: ", kH, " kW: ", kW);
TORCH_CHECK(dW > 0 && dH > 0,
"stride should be greater than zero, but got "
"dH: ", dH, " dW: ", dW);
TORCH_CHECK(dilationH > 0 && dilationW > 0,
"dilation should be greater than zero, but got ",
"dilationH: ", dilationH, " dilationW: ", dilationW);
TORCH_CHECK(input.numel() > 0 && (ndim == 3 || ndim == 4),
"non-empty 3D or 4D input tensor expected but got ndim: ", ndim);
TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
"pad should be smaller than half of kernel size, but got ",
"padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);
if (outputWidth < 1 || outputHeight < 1) {
AT_ERROR("Given input size: (",
nInputPlane, "x", inputHeight, "x", inputWidth, "). ",
"Calculated output size: (",
nOutputPlane, "x", outputHeight, "x", outputWidth, "). ",
"Output size is too small");
}
}
static inline void
max_pool2d_with_indices_shape_check(
const Tensor& input,
const Tensor& gradOutput,
const Tensor& indices,
int64_t nbatch,
int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
int64_t nInputPlane,
int64_t inputHeight, int64_t inputWidth,
int64_t outputHeight, int64_t outputWidth,
bool cuda=false)
{
max_pool2d_with_indices_shape_check(
input,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
const int64_t ndim = input.ndimension();
const int64_t nOutputPlane = nInputPlane;
check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
if (cuda) {
check_dim_size(indices, 4, 0, nbatch);
check_dim_size(indices, 4, 1, nOutputPlane);
check_dim_size(indices, 4, 2, outputHeight);
check_dim_size(indices, 4, 3, outputWidth);
}
else {
check_dim_size(indices, ndim, ndim-3, nOutputPlane);
check_dim_size(indices, ndim, ndim-2, outputHeight);
check_dim_size(indices, ndim, ndim-1, outputWidth);
}
}
} // namespace
} // at::native
} // at
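
A self-contained check of the output-shape formula above (numbers are illustrative; the arithmetic mirrors pooling_output_shape with ceil_mode=false):

#include <cassert>
#include <cstdint>

int main() {
  // inputSize=8, kernel=3, pad=1, stride=2, dilation=2.
  // Effective kernel extent: dilation * (kernel - 1) + 1 = 5.
  const int64_t inputSize = 8, kernel = 3, pad = 1, stride = 2, dilation = 2;
  const int64_t outputSize =
      (inputSize + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
  assert(outputSize == 3);  // (8 + 2 - 4 - 1) / 2 + 1
  return 0;
}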


@ -0,0 +1,499 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/DilatedMaxPool.h>
#include <tuple>
namespace at {
namespace native {
namespace {
template <typename scalar_t>
static void max_pool2d_with_indices_single_out_frame(
scalar_t *input_p,
scalar_t *output_p,
int64_t *ind_p,
int64_t nslices,
int64_t iwidth,
int64_t iheight,
int64_t owidth,
int64_t oheight,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH
)
{
at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
for (auto k = start; k < end; k++)
{
/* loop over output */
int64_t i, j;
scalar_t *ip = input_p + k*iwidth*iheight;
for(i = 0; i < oheight; i++)
{
for(j = 0; j < owidth; j++)
{
int64_t hstart = i * dH - padH;
int64_t wstart = j * dW - padW;
int64_t hend = std::min(hstart + (kH - 1) * dilationH + 1, iheight);
int64_t wend = std::min(wstart + (kW - 1) * dilationW + 1, iwidth);
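// Clamp a padded (negative) window start into the image while staying
// on the dilation grid, so every sampled position remains valid.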
while(hstart < 0)
hstart += dilationH;
while(wstart < 0)
wstart += dilationW;
/* local pointers */
scalar_t *op = output_p + k*owidth*oheight + i*owidth + j;
int64_t *indp = ind_p + k*owidth*oheight + i*owidth + j;
/* compute local max: */
int64_t maxindex = -1;
scalar_t maxval = -std::numeric_limits<scalar_t>::max();
int64_t tcntr = 0;
int64_t x,y;
for(y = hstart; y < hend; y += dilationH)
{
for(x = wstart; x < wend; x += dilationW)
{
tcntr = y*iwidth + x;
scalar_t val = *(ip + tcntr);
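// NaN wins the comparison, so NaNs propagate to the pooled output.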
if ((val > maxval) || std::isnan(val))
{
maxval = val;
maxindex = tcntr;
}
}
}
/* set output to local max */
*op = maxval;
/* store location of max */
*indp = maxindex;
}
}
}
});
}
template <typename scalar_t>
static void max_pool2d_with_indices_out_frame(
scalar_t *input_data,
scalar_t *output_data,
int64_t *indices_data,
int64_t nbatch,
int64_t nInputPlane,
int64_t inputWidth,
int64_t inputHeight,
int64_t outputWidth,
int64_t outputHeight,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH)
{
at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
for (auto p = start; p < end; p++) {
max_pool2d_with_indices_single_out_frame(
input_data+p*nInputPlane*inputWidth*inputHeight,
output_data+p*nInputPlane*outputWidth*outputHeight,
indices_data+p*nInputPlane*outputWidth*outputHeight,
nInputPlane,
inputWidth, inputHeight,
outputWidth, outputHeight,
kW, kH, dW, dH,
padW, padH,
dilationW, dilationH);
}
});
}
void max_pool2d_with_indices_out_cpu_template(
Tensor& output,
Tensor& indices,
const Tensor& input_,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode)
{
// XXX JIT: Pooling.cpp allows stride.empty().
// XXX IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
TORCH_CHECK(kernel_size.size() == 2 &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2) &&
(dilation.size() == 1 || dilation.size() == 2),
"max_pool2d_with_indices: internal error: all IntArrayRef sizes must be 2");
TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
"non-empty 3D or 4D (batch mode) tensor expected for input");
const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);
const int padH = safe_downcast<int, int64_t>(padding[0]);
const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);
/* sizes */
const int64_t nbatch = input_.ndimension() == 4 ? input_.size(-4) : 1;
const int64_t nInputPlane = input_.size(-3);
const int64_t inputHeight = input_.size(-2);
const int64_t inputWidth = input_.size(-1);
const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
max_pool2d_with_indices_shape_check(
input_,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
nInputPlane,
inputHeight, inputWidth,
outputHeight, outputWidth);
/* get contiguous input */
Tensor input = input_.contiguous();
/* resize output */
if (input.ndimension() == 3)
{
output.resize_({nInputPlane, outputHeight, outputWidth});
/* indices will contain the locations for each output point */
indices.resize_({nInputPlane, outputHeight, outputWidth});
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
"max_pool2d_with_indices_cpu",
[&] {
/* get raw pointers */
scalar_t *input_data = input.data<scalar_t>();
scalar_t *output_data = output.data<scalar_t>();
int64_t *indices_data = indices.data<int64_t>();
max_pool2d_with_indices_single_out_frame(
input_data, output_data,
indices_data,
nInputPlane,
inputWidth, inputHeight,
outputWidth, outputHeight,
kW, kH, dW, dH,
padW, padH,
dilationW, dilationH);
}
);
}
else
{
output.resize_({nbatch, nInputPlane, outputHeight, outputWidth});
/* indices will contain the locations for each output point */
indices.resize_({nbatch, nInputPlane, outputHeight, outputWidth});
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
"max_pool2d_with_indices_cpu",
[&] {
scalar_t *input_data = input.data<scalar_t>();
scalar_t *output_data = output.data<scalar_t>();
int64_t *indices_data = indices.data<int64_t>();
max_pool2d_with_indices_out_frame(
input_data,
output_data,
indices_data,
nbatch,
nInputPlane,
inputWidth, inputHeight,
outputWidth, outputHeight,
kW, kH, dW, dH,
padW, padH,
dilationW, dilationH);
}
);
}
}
template <typename scalar_t>
static void max_pool2d_with_indices_backward_single_out_frame(
scalar_t *gradInput_p,
scalar_t *gradOutput_p,
int64_t *ind_p,
int64_t nInputPlane,
int64_t inputWidth,
int64_t inputHeight,
int64_t outputWidth,
int64_t outputHeight,
int dW,
int dH)
{
at::parallel_for(0, nInputPlane, 0, [&](int64_t start, int64_t end) {
for (auto k = start; k < end; k++)
{
scalar_t *gradInput_p_k = gradInput_p + k*inputWidth*inputHeight;
scalar_t *gradOutput_p_k = gradOutput_p + k*outputWidth*outputHeight;
int64_t *ind_p_k = ind_p + k*outputWidth*outputHeight;
/* calculate max points */
int64_t i, j;
for(i = 0; i < outputHeight; i++)
{
for(j = 0; j < outputWidth; j++)
{
/* retrieve position of max */
int64_t maxp = ind_p_k[i*outputWidth + j];
if (maxp != -1) {
/* update gradient */
gradInput_p_k[maxp] += gradOutput_p_k[i*outputWidth + j];
}
}
}
}
});
}
template <typename scalar_t>
static void max_pool2d_with_indices_backward_out_frame(
scalar_t *gradInput_data,
scalar_t *gradOutput_data,
int64_t *indices_data,
int64_t nbatch,
int64_t nInputPlane,
int64_t inputWidth,
int64_t inputHeight,
int64_t outputWidth,
int64_t outputHeight,
int dW,
int dH)
{
at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
for (auto p = start; p < end; p++) {
max_pool2d_with_indices_backward_single_out_frame<scalar_t>(
gradInput_data+p*nInputPlane*inputWidth*inputHeight,
gradOutput_data+p*nInputPlane*outputWidth*outputHeight,
indices_data+p*nInputPlane*outputWidth*outputHeight,
nInputPlane,
inputWidth, inputHeight,
outputWidth, outputHeight,
dW, dH);
}
});
}
Tensor& max_pool2d_with_indices_backward_out_cpu_template(
Tensor& gradInput,
const Tensor& gradOutput_,
const Tensor& input,
const Tensor& indices,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode)
{
// XXX JIT: Pooling.cpp allows stride.empty().
// XXX IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
TORCH_CHECK(kernel_size.size() == 2 &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2) &&
(dilation.size() == 1 || dilation.size() == 2),
"max_pool2d_with_indices: internal error: all IntArrayRef sizes must be 2");
TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
"non-empty 3D or 4D (batch mode) tensor expected for input");
const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);
const int padH = safe_downcast<int, int64_t>(padding[0]);
const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);
/* get contiguous gradOutput */
const Tensor gradOutput = gradOutput_.contiguous();
/* resize */
gradInput.resize_as_(input);
gradInput.zero_();
/* sizes */
const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
const int64_t nInputPlane = input.size(-3);
const int64_t inputHeight = input.size(-2);
const int64_t inputWidth = input.size(-1);
const int64_t outputHeight = gradOutput.size(-2);
const int64_t outputWidth = gradOutput.size(-1);
/* XXX preserve the existing shape check behavior */
const int64_t outputHeight_for_shape_check = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
const int64_t outputWidth_for_shape_check = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
max_pool2d_with_indices_shape_check(
input,
gradOutput_,
indices,
nbatch,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
nInputPlane,
inputHeight, inputWidth,
outputHeight_for_shape_check, outputWidth_for_shape_check);
/* backprop */
if (input.ndimension() == 3)
{
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
"max_pool2d_with_indices_backward",
[&] {
/* get raw pointers */
scalar_t *gradInput_data = gradInput.data<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
int64_t *indices_data = indices.data<int64_t>();
max_pool2d_with_indices_backward_single_out_frame(
gradInput_data, gradOutput_data,
indices_data,
nInputPlane,
inputWidth, inputHeight,
outputWidth, outputHeight,
dW, dH);
}
);
}
else
{
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
"max_pool2d_with_indices_backward",
[&] {
/* get raw pointers */
scalar_t *gradInput_data = gradInput.data<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
int64_t *indices_data = indices.data<int64_t>();
max_pool2d_with_indices_backward_out_frame<scalar_t>(
gradInput_data, gradOutput_data,
indices_data,
nbatch,
nInputPlane,
inputWidth, inputHeight,
outputWidth, outputHeight,
dW, dH);
}
);
}
return gradInput;
}
} // namespace
std::tuple<Tensor&, Tensor&> max_pool2d_with_indices_out_cpu(
Tensor& output,
Tensor& indices,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode)
{
max_pool2d_with_indices_out_cpu_template(
output,
indices,
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode);
return std::tuple<Tensor&, Tensor&>(output, indices);
}
std::tuple<Tensor, Tensor> max_pool2d_with_indices_cpu(
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode)
{
Tensor output = at::empty({0}, input.options());
Tensor indices = at::empty({0}, input.options().dtype(kLong));
max_pool2d_with_indices_out_cpu_template(
output,
indices,
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode);
return std::tuple<Tensor, Tensor>(output, indices);
}
Tensor& max_pool2d_with_indices_backward_out_cpu(
Tensor& gradInput,
const Tensor& gradOutput_,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices)
{
max_pool2d_with_indices_backward_out_cpu_template(
gradInput,
gradOutput_,
input,
indices,
kernel_size,
stride,
padding,
dilation,
ceil_mode);
return gradInput;
}
Tensor max_pool2d_with_indices_backward_cpu(
const Tensor& gradOutput_,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices)
{
auto gradInput = at::zeros_like(input);
max_pool2d_with_indices_backward_out_cpu_template(
gradInput,
gradOutput_,
input,
indices,
kernel_size,
stride,
padding,
dilation,
ceil_mode);
return gradInput;
}
} // at::native
} // at
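
The CPU backward above reduces to a scatter-add keyed by the stored indices; a minimal standalone sketch with plain arrays (names and types are illustrative):

#include <cstdint>
#include <vector>

// Add each output gradient to the input element that produced the max;
// an index of -1 marks an empty pooling window and is skipped.
void route_grads(const std::vector<float>& gradOutput,
                 const std::vector<int64_t>& indices,
                 std::vector<float>& gradInput) {
  for (std::size_t o = 0; o < gradOutput.size(); ++o) {
    const int64_t maxp = indices[o];
    if (maxp != -1) {
      gradInput[maxp] += gradOutput[o];
    }
  }
}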


@ -364,22 +364,6 @@ Tensor avg_pool3d_backward(const Tensor & grad_output, const Tensor & self, IntA
return at::legacy::th::_thnn_avg_pool3d_backward(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad);
}
std::tuple<Tensor &,Tensor &> max_pool2d_with_indices_out(Tensor & output, Tensor & indices, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
return at::legacy::th::_thnn_max_pool2d_with_indices_forward_out(output, indices, self, kernel_size, stride, padding, dilation, ceil_mode);
}
std::tuple<Tensor,Tensor> max_pool2d_with_indices(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
return at::legacy::th::_thnn_max_pool2d_with_indices_forward(self, kernel_size, stride, padding, dilation, ceil_mode);
}
Tensor & max_pool2d_with_indices_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, const Tensor & indices) {
return at::legacy::th::_thnn_max_pool2d_with_indices_backward_out(grad_input, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices);
}
Tensor max_pool2d_with_indices_backward(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, const Tensor & indices) {
return at::legacy::th::_thnn_max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices);
}
std::tuple<Tensor &,Tensor &> max_pool3d_with_indices_out(Tensor & output, Tensor & indices, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
return at::legacy::th::_thnn_max_pool3d_with_indices_forward_out(output, indices, self, kernel_size, stride, padding, dilation, ceil_mode);
}


@ -0,0 +1,420 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/native/DilatedMaxPool.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <THC/THCNumerics.cuh>
#include <c10/macros/Macros.h>
namespace at {
namespace native {
namespace {
__device__ inline int min(int a, int b) {
return a <= b ? a : b;
}
// kernels borrowed from Caffe
template <typename scalar_t, typename accscalar_t>
__global__ void MaxPoolForward(const int nthreads, const scalar_t* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, scalar_t* top_data,
int64_t* top_mask) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
while(hstart < 0)
hstart += dilation_h;
while(wstart < 0)
wstart += dilation_w;
accscalar_t maxval = THCNumerics<accscalar_t>::min();
int maxidx = -1;
bottom_data += (n * channels + c) * height * width;
for (int h = hstart; h < hend; h += dilation_h) {
for (int w = wstart; w < wend; w += dilation_w) {
scalar_t val = bottom_data[h * width + w];
if ((ScalarConvert<scalar_t, accscalar_t>::to(val) > maxval) || THCNumerics<scalar_t>::isnan(val)) {
maxidx = h * width + w;
maxval = ScalarConvert<scalar_t, accscalar_t>::to(val);
}
}
}
top_data[index] = ScalarConvert<scalar_t, accscalar_t>::to(maxval);
top_mask[index] = maxidx;
}
}
static const int BACKWARD_THREADS = 256;
template <typename scalar_t, typename accscalar_t>
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2(BACKWARD_THREADS, 4)
#else
C10_LAUNCH_BOUNDS_2(BACKWARD_THREADS, 8)
#endif
__global__ void MaxPoolBackward(const int nthreads, const scalar_t* top_diff,
const int64_t* top_mask, const int num, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
scalar_t* bottom_diff) {
CUDA_KERNEL_LOOP(index, height*width) {
int h = index/width;
int w = index - h * width;
//get some templating performance benefits without actually templating
int phstart, phend, pwstart, pwend;
if (stride_h == 1) {
phstart =
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) + 1;
phend = min((h + pad_h) + 1, pooled_height);
} else if (stride_h == 2) {
phstart =
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) / 2 + 1;
phend = min((h + pad_h) / 2 + 1, pooled_height);
} else {
phstart =
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) / stride_h + 1;
phend = min((h + pad_h) / stride_h + 1, pooled_height);
}
if (stride_w == 1) {
pwstart =
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) + 1;
pwend = min((w + pad_w) + 1, pooled_width);
} else if (stride_w == 2) {
pwstart =
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) / 2 + 1;
pwend = min((w + pad_w) / 2 + 1, pooled_width);
} else {
pwstart =
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) / stride_w + 1;
pwend = min((w + pad_w) / stride_w + 1, pooled_width);
}
for (int n = blockIdx.y; n < num; n += gridDim.y)
for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
accscalar_t gradient = accscalar_t(0);
// Index with an explicit offset instead of mutating the parameter pointers:
// when the grid is clamped a thread handles several (n, c) pairs, and a
// repeated `top_diff += offset` would compound the offsets.
int offset = (n * channels + c) * pooled_height * pooled_width;
//get some templating performance benefits without actually templating
if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (top_mask[offset + ph * pooled_width + pw] == h * width + w) {
gradient += ScalarConvert<scalar_t, accscalar_t>::to(top_diff[offset + ph * pooled_width + pw]);
}
}
}
} else {
if (top_mask[offset + phstart * pooled_width + pwstart] == h * width + w) {
gradient += ScalarConvert<scalar_t, accscalar_t>::to(top_diff[offset + phstart * pooled_width + pwstart]);
}
}
bottom_diff[(n*channels+c)*height*width+index] = ScalarConvert<accscalar_t, scalar_t>::to(gradient);
}
}
}
void max_pool2d_with_indices_out_cuda_template(
Tensor& output,
Tensor& indices,
const Tensor& input_,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode)
{
TensorArg output_arg{ output, "output", 1 };
TensorArg indices_arg{ indices, "indices", 2 };
TensorArg input_arg{ input_, "input_", 3 };
checkAllSameGPU("max_pool2d_with_indices_out_cuda",
{output_arg, indices_arg, input_arg});
// XXX JIT: Pooling.cpp allows stride.empty().
// XXX IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
TORCH_CHECK(kernel_size.size() == 2 &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2) &&
(dilation.size() == 1 || dilation.size() == 2),
"max_pool2d_with_indices: internal error: all IntArrayRef sizes must be 2");
TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
"non-empty 3D or 4D (batch mode) tensor expected for input");
const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);
const int padH = safe_downcast<int, int64_t>(padding[0]);
const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);
const int64_t nbatch = input_.ndimension() == 4 ? input_.size(-4) : 1;
const int64_t nInputPlane = input_.size(-3);
const int64_t inputHeight = input_.size(-2);
const int64_t inputWidth = input_.size(-1);
const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
max_pool2d_with_indices_shape_check(
input_,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
nInputPlane,
inputHeight, inputWidth,
outputHeight, outputWidth);
Tensor input = input_.contiguous();
output.resize_({nbatch, nInputPlane, outputHeight, outputWidth});
indices.resize_({nbatch, nInputPlane, outputHeight, outputWidth});
const int count = safe_downcast<int, int64_t>(output.numel());
const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
BACKWARD_THREADS);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
"max_pool2d_with_indices_out_cuda_frame",
[&] {
using accscalar_t = acc_type<scalar_t, true>;
scalar_t *output_data = output.data<scalar_t>();
scalar_t *input_data = input.data<scalar_t>();
int64_t *indices_data = indices.data<int64_t>();
MaxPoolForward<scalar_t, scalar_t>
<<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
count, input_data,
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW, output_data, indices_data);
}
);
const cudaError_t err = cudaGetLastError();
TORCH_CHECK(err == cudaSuccess,
"max_pool2d_with_indices_out_cuda_frame failed with error code ",
err);
if(input.ndimension() == 3) {
output.resize_({nInputPlane, outputHeight, outputWidth});
}
}
void max_pool2d_with_indices_backward_out_cuda_template(
Tensor& gradInput,
const Tensor& gradOutput_,
const Tensor& input_,
const Tensor& indices,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode)
{
TensorArg gradInput_arg{ gradInput, "gradInput", 1 };
TensorArg gradOutput_arg{ gradOutput_, "gradOutput_", 2 };
TensorArg input_arg{ input_, "input_", 3 };
TensorArg indices_arg{ indices, "indices", 4 };
checkAllSameGPU("max_pool2d_with_indices_out_cuda",
{gradInput_arg, gradOutput_arg, input_arg, indices_arg});
// XXX JIT: Pooling.cpp allows stride.empty().
// XXX IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
TORCH_CHECK(kernel_size.size() == 2 &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2) &&
(dilation.size() == 1 || dilation.size() == 2),
"max_pool2d_with_indices: internal error: all IntArrayRef sizes must be 2");
TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
"non-empty 3D or 4D (batch mode) tensor expected for input");
const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);
const int padH = safe_downcast<int, int64_t>(padding[0]);
const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);
const Tensor input = input_.contiguous();
const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
const int64_t nInputPlane = input.size(-3);
const int64_t inputHeight = input.size(-2);
const int64_t inputWidth = input.size(-1);
const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
max_pool2d_with_indices_shape_check(
input_,
gradOutput_,
indices,
nbatch,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
nInputPlane,
inputHeight, inputWidth,
outputHeight, outputWidth,
/*cuda=*/ true);
const Tensor gradOutput = gradOutput_.contiguous();
gradInput.resize_as_(input);
int64_t count = input.numel();
dim3 grid;
int imgcount = inputWidth * inputHeight;
const int blocks = (imgcount + BACKWARD_THREADS - 1) / BACKWARD_THREADS;
grid.x = blocks;
grid.y = nbatch;
grid.z = nInputPlane;
uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
if (maxGridY < grid.y) grid.y = maxGridY;
if (maxGridZ < grid.z) grid.z = maxGridZ;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
"max_pool2d_with_indices_out_cuda_frame",
[&] {
using accscalar_t = acc_type<scalar_t, true>;
scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
scalar_t *gradInput_data = gradInput.data<scalar_t>();
int64_t *indices_data = indices.data<int64_t>();
MaxPoolBackward<scalar_t, accscalar_t>
<<<grid, BACKWARD_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
gradOutput_data,
indices_data,
nbatch,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
gradInput_data);
}
);
const cudaError_t err = cudaGetLastError();
TORCH_CHECK(err == cudaSuccess,
"max_pool2d_with_indices_backward_out_cuda failed with error code ",
err);
}
} // namespace
std::tuple<Tensor&, Tensor&> max_pool2d_with_indices_out_cuda(
Tensor& output,
Tensor& indices,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode)
{
max_pool2d_with_indices_out_cuda_template(
output,
indices,
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode);
return std::tuple<Tensor&, Tensor&>(output, indices);
}
std::tuple<Tensor, Tensor> max_pool2d_with_indices_cuda(
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode)
{
Tensor output = at::empty({0}, input.options());
Tensor indices = at::empty({0}, input.options().dtype(kLong));
max_pool2d_with_indices_out_cuda_template(
output,
indices,
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode);
return std::tuple<Tensor, Tensor>(output, indices);
}
Tensor& max_pool2d_with_indices_backward_out_cuda(
Tensor& gradInput,
const Tensor& gradOutput_,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices)
{
max_pool2d_with_indices_backward_out_cuda_template(
gradInput,
gradOutput_,
input,
indices,
kernel_size,
stride,
padding,
dilation,
ceil_mode);
return gradInput;
}
Tensor max_pool2d_with_indices_backward_cuda(
const Tensor& gradOutput_,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices)
{
auto gradInput = at::zeros_like(input);
max_pool2d_with_indices_backward_out_cuda_template(
gradInput,
gradOutput_,
input,
indices,
kernel_size,
stride,
padding,
dilation,
ceil_mode);
return gradInput;
}
} // at::native
} // at
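
MaxPoolBackward above inverts the pooling map: each input pixel visits only the pooled windows that can contain it. A self-contained check of the row bounds (illustrative numbers):

#include <algorithm>
#include <cassert>

int main() {
  // kernel 3, stride 2, pad 1, dilation 1 on an 8-row input (4 pooled rows).
  const int kernel_h = 3, stride_h = 2, pad_h = 1, dilation_h = 1;
  const int pooled_height = 4, h = 3;  // input row 3
  const int ext = (kernel_h - 1) * dilation_h + 1;  // effective extent
  const int phstart =
      (h + pad_h < ext) ? 0 : (h + pad_h - ext) / stride_h + 1;
  const int phend = std::min((h + pad_h) / stride_h + 1, pooled_height);
  // Windows ph=1 (rows 1..3) and ph=2 (rows 3..5) cover input row 3.
  assert(phstart == 1 && phend == 3);
  return 0;
}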


@ -3792,18 +3792,30 @@
CUDA: fractional_max_pool3d_backward_cuda
# Return: (Tensor output, Tensor indices)
- func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
- func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
python_module: nn
dispatch:
CPU: max_pool2d_with_indices_out_cpu
CUDA: max_pool2d_with_indices_out_cuda
# Return: (Tensor output, Tensor indices)
- func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
python_module: nn
dispatch:
CPU: max_pool2d_with_indices_cpu
CUDA: max_pool2d_with_indices_cuda
- func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
dispatch:
CPU: max_pool2d_with_indices_backward_out_cpu
CUDA: max_pool2d_with_indices_backward_out_cuda
- func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor
python_module: nn
dispatch:
CPU: max_pool2d_with_indices_backward_cpu
CUDA: max_pool2d_with_indices_backward_cuda
# Return: (Tensor output, Tensor indices)
- func: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
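
The schema keeps stride=[] as its default; as the template functions above show, an empty stride falls back to the kernel size, giving non-overlapping windows. A usage sketch (shapes illustrative):

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::randn({1, 1, 4, 4});
  // stride omitted -> stride = kernel_size = {2, 2}.
  auto result = at::max_pool2d_with_indices(x, /*kernel_size=*/{2, 2});
  // std::get<0>(result).sizes() == [1, 1, 2, 2]
  return 0;
}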


@ -134,14 +134,6 @@
output: 'false'
grad_input: 'false'
- name: _thnn_max_pool2d_with_indices(Tensor self, IntArrayRef[2] kernel_size, IntArrayRef[2] stride={}, IntArrayRef[2] padding=0, IntArrayRef[2] dilation=1, bool ceil_mode=false)
cname: SpatialDilatedMaxPooling
default_init:
stride: kernel_size
scalar_check:
output: 'false'
grad_input: 'false'
- name: _thnn_max_pool3d_with_indices(Tensor self, IntArrayRef[3] kernel_size, IntArrayRef[3] stride={}, IntArrayRef[3] padding=0, IntArrayRef[3] dilation=1, bool ceil_mode=false)
cname: VolumetricDilatedMaxPooling
default_init:


@ -33,10 +33,8 @@ ${CMAKE_CURRENT_SOURCE_DIR}/SpatialConvolutionMM.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialCrossMapLRN.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialDepthwiseConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialDilatedConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialDilatedMaxPooling.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialFullConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialFullDilatedConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialMaxPooling.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialMaxUnpooling.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialSubSampling.cu
${CMAKE_CURRENT_SOURCE_DIR}/Sqrt.cu


@ -1,121 +0,0 @@
#include <THCUNN/THCUNN.h>
#include <THC/THCTensor.hpp>
#include <TH/THHalf.h>
#include <THCUNN/THCHalfAutoNumerics.cuh>
#include <THC/THCNumerics.cuh>
#include <THCUNN/common.h>
#include <c10/macros/Macros.h>
// kernels borrowed from Caffe
template <typename Dtype, typename AccType>
__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, Dtype* top_data,
int64_t* top_mask) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
while(hstart < 0)
hstart += dilation_h;
while(wstart < 0)
wstart += dilation_w;
AccType maxval = THCNumerics<AccType>::min();
int maxidx = -1;
bottom_data += (n * channels + c) * height * width;
for (int h = hstart; h < hend; h += dilation_h) {
for (int w = wstart; w < wend; w += dilation_w) {
Dtype val = bottom_data[h * width + w];
if ((ScalarConvert<Dtype, AccType>::to(val) > maxval) || THCNumerics<Dtype>::isnan(val)) {
maxidx = h * width + w;
maxval = ScalarConvert<Dtype, AccType>::to(val);
}
}
}
top_data[index] = ScalarConvert<AccType, Dtype>::to(maxval);
top_mask[index] = maxidx;
}
}
const int BACKWARD_THREADS = 256;
template <typename Dtype, typename AccType>
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2(BACKWARD_THREADS, 4)
#else
C10_LAUNCH_BOUNDS_2(BACKWARD_THREADS, 8)
#endif
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
const int64_t* top_mask, const int num, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, height*width) {
int h = index/width;
int w = index - h * width;
//get some templating performance benefits without actually templating
int phstart, phend, pwstart, pwend;
if (stride_h == 1) {
phstart =
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) + 1;
phend = min((h + pad_h) + 1, pooled_height);
} else if (stride_h == 2) {
phstart =
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) / 2 + 1;
phend = min((h + pad_h) / 2 + 1, pooled_height);
} else {
phstart =
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) / stride_h + 1;
phend = min((h + pad_h) / stride_h + 1, pooled_height);
}
if (stride_w == 1) {
pwstart =
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) + 1;
pwend = min((w + pad_w) + 1, pooled_width);
} else if (stride_w == 2) {
pwstart =
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) / 2 + 1;
pwend = min((w + pad_w) / 2 + 1, pooled_width);
} else {
pwstart =
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) / stride_w + 1;
pwend = min((w + pad_w) / stride_w + 1, pooled_width);
}
for (int n = blockIdx.y; n < num; n += gridDim.y)
for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
AccType gradient = AccType(0);
int offset = (n * channels + c) * pooled_height * pooled_width;
top_diff += offset;
top_mask += offset;
//get some templating performance benefits without actually templating
if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (top_mask[ph * pooled_width + pw] == h * width + w) {
gradient += ScalarConvert<Dtype, AccType>::to(top_diff[ph * pooled_width + pw]);
}
}
}
} else {
if (top_mask[phstart * pooled_width + pwstart] == h * width + w) {
gradient += ScalarConvert<Dtype, AccType>::to(top_diff[phstart * pooled_width + pwstart]);
}
}
bottom_diff[(n*channels+c)*height*width+index] = ScalarConvert<AccType, Dtype>::to(gradient);
}
}
}
#include <THCUNN/generic/SpatialDilatedMaxPooling.cu>
#include <THC/THCGenerateFloatTypes.h>


@ -1,4 +0,0 @@
#include <THCUNN/THCUNN.h>
#include <THCUNN/generic/SpatialMaxPooling.cu>
#include <THC/THCGenerateFloatTypes.h>


@ -1,199 +0,0 @@
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "THCUNN/generic/SpatialDilatedMaxPooling.cu"
#else
#include <THCUNN/common.h>
#include <THCUNN/generic/pooling_shape.h>
#include <ATen/cuda/CUDAContext.h>
static inline void THNN_(SpatialDilatedMaxPooling_shapeCheck)(
THCState *state,
THCTensor *input, THCTensor *gradOutput, THCIndexTensor *indices,
int kH, int kW, int dH, int dW, int padH, int padW,
int dilationH, int dilationW, bool ceil_mode) {
THArgCheck(kW > 0 && kH > 0, 5,
"kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
THArgCheck(dW > 0 && dH > 0, 8,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
THArgCheck(dilationH > 0 && dilationW > 0, 12,
"dilation should be greater than zero, but got dilationH: %d dilationW: %d",
dilationH, dilationW);
int ndim = input->dim();
int dimf = 0;
int dimh = 1;
int dimw = 2;
int batchSize = 1;
if (ndim == 4) {
batchSize = input->size(0);
dimf++;
dimh++;
dimw++;
}
THCUNN_argCheck(state, !input->is_empty() && (ndim == 3 || ndim == 4), 2, input,
"non-empty 3D or 4D input tensor expected but got: %s");
THArgCheck(kW/2 >= padW && kH/2 >= padH, 2,
"pad should be smaller than half of kernel size, but got "
"padW = %d, padH = %d, kW = %d, kH = %d",
padW, padH, kW, kH);
int64_t nInputPlane = input->size(dimh-1);
int64_t nInputRows = input->size(dimh);
int64_t nInputCols = input->size(dimw);
int64_t nOutputPlane = nInputPlane;
int64_t nOutputRows = pooling_output_shape<int64_t>(nInputRows, kH, padH, dH, dilationH, ceil_mode);
int64_t nOutputCols = pooling_output_shape<int64_t>(nInputCols, kW, padW, dW, dilationW, ceil_mode);
if (nOutputCols < 1 || nOutputRows < 1)
THError("Given input size: (%dx%dx%d). "
"Calculated output size: (%dx%dx%d). Output size is too small",
nInputPlane,nInputRows,nInputCols,nInputPlane,nOutputRows,nOutputCols);
if (gradOutput != NULL) {
THCUNN_check_dim_size(state, gradOutput, ndim, dimf, nOutputPlane);
THCUNN_check_dim_size(state, gradOutput, ndim, dimh, nOutputRows);
THCUNN_check_dim_size(state, gradOutput, ndim, dimw, nOutputCols);
}
if (indices != NULL) {
THCUNN_check_dim_size_indices(state, indices, 4, 0, batchSize);
THCUNN_check_dim_size_indices(state, indices, 4, 1, nOutputPlane);
THCUNN_check_dim_size_indices(state, indices, 4, 2, nOutputRows);
THCUNN_check_dim_size_indices(state, indices, 4, 3, nOutputCols);
}
}
void THNN_(SpatialDilatedMaxPooling_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
THCIndexTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
bool ceil_mode)
{
THCUNN_assertSameGPU(state, 3, input, output, indices);
THNN_(SpatialDilatedMaxPooling_shapeCheck)
(state, input, NULL, NULL, kH, kW, dH, dW,
padH, padW, dilationH, dilationW, ceil_mode);
int64_t nInputCols, nInputRows, nInputPlane, batchSize;
int64_t nOutputCols, nOutputRows;
if (input->dim() == 3) {
nInputCols = input->size(2);
nInputRows = input->size(1);
nInputPlane = input->size(0);
batchSize = 1;
}
else
{
nInputCols = input->size(3);
nInputRows = input->size(2);
nInputPlane = input->size(1);
batchSize = input->size(0);
}
nOutputCols = pooling_output_shape<int64_t>(nInputCols, kW, padW, dW, dilationW, ceil_mode);
nOutputRows = pooling_output_shape<int64_t>(nInputRows, kH, padH, dH, dilationH, ceil_mode);
input = THCTensor_(newContiguous)(state, input);
scalar_t* input_data = THCTensor_(data)(state, input);
THCTensor_(resize4d)(state, output, batchSize, nInputPlane, nOutputRows, nOutputCols);
THCUNN_resizeAs_indices(state, indices, output);
THCIndex_t* indices_data = THCIndexTensor_(data)(state, indices);
scalar_t* output_data = THCTensor_(data)(state, output);
int count = THCTensor_(nElement)(state, output);
MaxPoolForward<scalar_t, accreal> <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
(count, input_data,
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
kH, kW, dH, dW, padH, padW, dilationH, dilationW, output_data, indices_data);
THCudaCheck(cudaGetLastError());
if(input->dim() == 3)
THCTensor_(resize3d)(state, output, nInputPlane, nOutputRows, nOutputCols);
THCTensor_(free)(state, input);
}
void THNN_(SpatialDilatedMaxPooling_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
THCIndexTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
bool ceil_mode)
{
THCUNN_assertSameGPU(state, 4, input, gradOutput, indices, gradInput);
THNN_(SpatialDilatedMaxPooling_shapeCheck)
(state, input, gradOutput, indices, kH, kW, dH, dW,
padH, padW, dilationH, dilationW, ceil_mode);
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
int64_t nInputCols, nInputRows, nInputPlane, batchSize;
int64_t nOutputCols, nOutputRows;
if (THTensor_nDimensionLegacyAll(input) == 3) {
nInputCols = input->size(2);
nInputRows = input->size(1);
nInputPlane = input->size(0);
batchSize = 1;
}
else
{
nInputCols = input->size(3);
nInputRows = input->size(2);
nInputPlane = input->size(1);
batchSize = input->size(0);
}
nOutputCols = pooling_output_shape<int64_t>(nInputCols, kW, padW, dW, dilationW, ceil_mode);
nOutputRows = pooling_output_shape<int64_t>(nInputRows, kH, padH, dH, dilationH, ceil_mode);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
THCTensor_(resizeAs)(state, gradInput, input);
int count = THCTensor_(nElement)(state, input);
dim3 grid;
int imgcount = nInputCols * nInputRows;
const int blocks = (imgcount + BACKWARD_THREADS - 1) / BACKWARD_THREADS;
grid.x = blocks;
grid.y = batchSize;
grid.z = nInputPlane;
uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
if (maxGridY < grid.y) grid.y = maxGridY;
if (maxGridZ < grid.z) grid.z = maxGridZ;
MaxPoolBackward<scalar_t, accreal> <<< grid, BACKWARD_THREADS, 0, THCState_getCurrentStream(state) >>>
(count,
THCTensor_(data)(state, gradOutput),
THCIndexTensor_(data)(state, indices),
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
THCTensor_(data)(state, gradInput));
THCudaCheck(cudaGetLastError());
THCTensor_(free)(state, gradOutput);
// clean
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
}
#endif


@ -1,40 +0,0 @@
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "THCUNN/generic/SpatialMaxPooling.cu"
#else
#include <THCUNN/common.h>
void THNN_(SpatialMaxPooling_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
THCIndexTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
bool ceil_mode)
{
THNN_(SpatialDilatedMaxPooling_updateOutput)(
state, input, output, indices,
kW, kH, dW, dH, padW, padH, 1, 1, ceil_mode);
}
void THNN_(SpatialMaxPooling_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
THCIndexTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
bool ceil_mode)
{
THNN_(SpatialDilatedMaxPooling_updateGradInput)(
state, input, gradOutput, gradInput, indices,
kW, kH, dW, dH, padW, padH, 1, 1, ceil_mode);
}
#endif


@ -665,29 +665,6 @@ THC_API void THNN_(SpatialFullDilatedConvolution_accGradParameters)(
int adjW, int adjH,
accreal scale);
THC_API void THNN_(SpatialDilatedMaxPooling_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
THCIndexTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
bool ceil_mode);
THC_API void THNN_(SpatialDilatedMaxPooling_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
THCIndexTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
bool ceil_mode);
THC_API void THNN_(SpatialFullConvolution_updateOutput)(
THCState *state,
THCTensor *input,
@ -727,27 +704,6 @@ THC_API void THNN_(SpatialFullConvolution_accGradParameters)(
int adjW, int adjH,
accreal scale);
THC_API void THNN_(SpatialMaxPooling_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
THCIndexTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
bool ceil_mode);
THC_API void THNN_(SpatialMaxPooling_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
THCIndexTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
bool ceil_mode);
THC_API void THNN_(SpatialMaxUnpooling_updateOutput)(
THCState *state,
THCTensor *input,


@ -1,368 +0,0 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "THNN/generic/SpatialDilatedMaxPooling.c"
#else
#include <THNN/generic/pooling_shape.h>
#include <algorithm>
#include <ATen/Parallel.h>
static inline void THNN_(SpatialDilatedMaxPooling_shapeCheck)(
THTensor *input, THTensor *gradOutput, THIndexTensor *indices,
int kH, int kW, int dH, int dW, int padH, int padW,
int dilationH, int dilationW, bool ceil_mode) {
THArgCheck(kW > 0 && kH > 0, 5,
"kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
THArgCheck(dW > 0 && dH > 0, 8,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
THArgCheck(dilationH > 0 && dilationW > 0, 12,
"dilation should be greater than zero, but got dilationH: %d dilationW: %d",
dilationH, dilationW);
int ndim = input->dim();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4) {
dimf++;
dimh++;
dimw++;
}
THNN_ARGCHECK(!input->is_empty() && (ndim == 3 || ndim == 4), 2, input,
"non-empty 3D or 4D input tensor expected but got: %s");
THArgCheck(kW/2 >= padW && kH/2 >= padH, 2,
"pad should be smaller than half of kernel size, but got "
"padW = %d, padH = %d, kW = %d, kH = %d",
padW, padH, kW, kH);
int64_t nInputPlane = input->size(dimh-1);
int64_t inputHeight = input->size(dimh);
int64_t inputWidth = input->size(dimw);
int64_t nOutputPlane = nInputPlane;
int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
if (outputWidth < 1 || outputHeight < 1)
THError("Given input size: (%dx%dx%d). "
"Calculated output size: (%dx%dx%d). Output size is too small",
nInputPlane,inputHeight,inputWidth,nInputPlane,outputHeight,outputWidth);
if (gradOutput != NULL) {
THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimf, nOutputPlane);
THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimh, outputHeight);
THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimw, outputWidth);
}
if (indices != NULL) {
THNN_CHECK_DIM_SIZE_INDICES(indices, ndim, dimf, nOutputPlane);
THNN_CHECK_DIM_SIZE_INDICES(indices, ndim, dimh, outputHeight);
THNN_CHECK_DIM_SIZE_INDICES(indices, ndim, dimw, outputWidth);
}
}
static void THNN_(SpatialDilatedMaxPooling_updateOutput_frame)(
scalar_t *input_p,
scalar_t *output_p,
THIndex_t *ind_p,
int64_t nslices,
int64_t iwidth,
int64_t iheight,
int64_t owidth,
int64_t oheight,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH
)
{
at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
for (auto k = start; k < end; k++)
{
/* loop over output */
int64_t i, j;
scalar_t *ip = input_p + k*iwidth*iheight;
for(i = 0; i < oheight; i++)
{
for(j = 0; j < owidth; j++)
{
int64_t hstart = i * dH - padH;
int64_t wstart = j * dW - padW;
int64_t hend = std::min(hstart + (kH - 1) * dilationH + 1, iheight);
int64_t wend = std::min(wstart + (kW - 1) * dilationW + 1, iwidth);
while(hstart < 0)
hstart += dilationH;
while(wstart < 0)
wstart += dilationW;
/* local pointers */
scalar_t *op = output_p + k*owidth*oheight + i*owidth + j;
THIndex_t *indp = ind_p + k*owidth*oheight + i*owidth + j;
/* compute local max: */
int64_t maxindex = -1;
scalar_t maxval = -THInf;
int64_t tcntr = 0;
int64_t x,y;
for(y = hstart; y < hend; y += dilationH)
{
for(x = wstart; x < wend; x += dilationW)
{
tcntr = y*iwidth + x;
scalar_t val = *(ip + tcntr);
if ((val > maxval) || std::isnan(val))
{
maxval = val;
maxindex = tcntr;
}
}
}
/* set output to local max */
*op = maxval;
/* store location of max */
*indp = maxindex;
}
}
}
});
}
void THNN_(SpatialDilatedMaxPooling_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
THIndexTensor *indices,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
bool ceil_mode)
{
int dimw = 2;
int dimh = 1;
int64_t nbatch = 1;
int64_t nInputPlane;
int64_t inputHeight;
int64_t inputWidth;
int64_t outputHeight;
int64_t outputWidth;
scalar_t *input_data;
scalar_t *output_data;
THIndex_t *indices_data;
THNN_(SpatialDilatedMaxPooling_shapeCheck)
(input, NULL, NULL, kH, kW, dH, dW,
padH, padW, dilationH, dilationW, ceil_mode);
if (input->dim() == 4)
{
nbatch = input->size(0);
dimw++;
dimh++;
}
/* sizes */
nInputPlane = input->size(dimh-1);
inputHeight = input->size(dimh);
inputWidth = input->size(dimw);
outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);
/* get contiguous input */
input = THTensor_(newContiguous)(input);
/* resize output */
if (input->dim() == 3)
{
THTensor_(resize3d)(output, nInputPlane, outputHeight, outputWidth);
/* indices will contain the locations for each output point */
THIndexTensor_(resize3d)(indices, nInputPlane, outputHeight, outputWidth);
input_data = input->data<scalar_t>();
output_data = output->data<scalar_t>();
indices_data = THIndexTensor_(data)(indices);
THNN_(SpatialDilatedMaxPooling_updateOutput_frame)
(input_data, output_data,
indices_data,
nInputPlane,
inputWidth, inputHeight,
outputWidth, outputHeight,
kW, kH, dW, dH,
padW, padH,
dilationW, dilationH
);
}
else
{
THTensor_(resize4d)(output, nbatch, nInputPlane, outputHeight, outputWidth);
/* indices will contain the locations for each output point */
THIndexTensor_(resize4d)(indices, nbatch, nInputPlane, outputHeight, outputWidth);
input_data = input->data<scalar_t>();
output_data = output->data<scalar_t>();
indices_data = THIndexTensor_(data)(indices);
at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
for (auto p = start; p < end; p++)
{
THNN_(SpatialDilatedMaxPooling_updateOutput_frame)
(input_data+p*nInputPlane*inputWidth*inputHeight,
output_data+p*nInputPlane*outputWidth*outputHeight,
indices_data+p*nInputPlane*outputWidth*outputHeight,
nInputPlane,
inputWidth, inputHeight,
outputWidth, outputHeight,
kW, kH, dW, dH,
padW, padH,
dilationW, dilationH
);
}
});
}
/* cleanup */
c10::raw::intrusive_ptr::decref(input);
}
static void THNN_(SpatialDilatedMaxPooling_updateGradInput_frame)(
scalar_t *gradInput_p,
scalar_t *gradOutput_p,
THIndex_t *ind_p,
int64_t nInputPlane,
int64_t inputWidth,
int64_t inputHeight,
int64_t outputWidth,
int64_t outputHeight,
int dW,
int dH)
{
at::parallel_for(0, nInputPlane, 0, [&](int64_t start, int64_t end) {
for (auto k = start; k < end; k++)
{
scalar_t *gradInput_p_k = gradInput_p + k*inputWidth*inputHeight;
scalar_t *gradOutput_p_k = gradOutput_p + k*outputWidth*outputHeight;
THIndex_t *ind_p_k = ind_p + k*outputWidth*outputHeight;
/* calculate max points */
int64_t i, j;
for(i = 0; i < outputHeight; i++)
{
for(j = 0; j < outputWidth; j++)
{
/* retrieve position of max */
int64_t maxp = ind_p_k[i*outputWidth + j];
if (maxp != -1) {
/* update gradient */
gradInput_p_k[maxp] += gradOutput_p_k[i*outputWidth + j];
}
}
}
}
});
}
void THNN_(SpatialDilatedMaxPooling_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
THIndexTensor *indices,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
bool ceil_mode)
{
int dimw = 2;
int dimh = 1;
int64_t nbatch = 1;
int nInputPlane;
int inputHeight;
int inputWidth;
int outputHeight;
int outputWidth;
scalar_t *gradInput_data;
scalar_t *gradOutput_data;
THIndex_t *indices_data;
THNN_(SpatialDilatedMaxPooling_shapeCheck)
(input, gradOutput, indices, kH, kW, dH, dW,
padH, padW, dilationH, dilationW, ceil_mode);
/* get contiguous gradOutput */
gradOutput = THTensor_(newContiguous)(gradOutput);
/* resize */
THTensor_(resizeAs)(gradInput, input);
THTensor_(zero)(gradInput);
if (input->dim() == 4) {
nbatch = input->size(0);
dimw++;
dimh++;
}
/* sizes */
nInputPlane = input->size(dimh-1);
inputHeight = input->size(dimh);
inputWidth = input->size(dimw);
outputHeight = gradOutput->size(dimh);
outputWidth = gradOutput->size(dimw);
/* get raw pointers */
gradInput_data = gradInput->data<scalar_t>();
gradOutput_data = gradOutput->data<scalar_t>();
indices_data = THIndexTensor_(data)(indices);
/* backprop */
if (input->dim() == 3)
{
THNN_(SpatialDilatedMaxPooling_updateGradInput_frame)
(gradInput_data, gradOutput_data,
indices_data,
nInputPlane,
inputWidth, inputHeight,
outputWidth, outputHeight,
dW, dH);
}
else
{
at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
for (auto p = start; p < end; p++)
{
THNN_(SpatialDilatedMaxPooling_updateGradInput_frame)
(gradInput_data+p*nInputPlane*inputWidth*inputHeight,
gradOutput_data+p*nInputPlane*outputWidth*outputHeight,
indices_data+p*nInputPlane*outputWidth*outputHeight,
nInputPlane,
inputWidth, inputHeight,
outputWidth, outputHeight,
dW, dH);
}
});
}
/* cleanup */
c10::raw::intrusive_ptr::decref(gradOutput);
}
#endif


@ -526,28 +526,6 @@ TH_API void THNN_(SpatialFullDilatedConvolution_accGradParameters)(
int adjW, int adjH,
accreal scale);
TH_API void THNN_(SpatialDilatedMaxPooling_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
THIndexTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
bool ceil_mode);
TH_API void THNN_(SpatialDilatedMaxPooling_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
THIndexTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
bool ceil_mode);
TH_API void THNN_(SpatialMaxUnpooling_updateOutput)(
THNNState *state,
THTensor *input,


@ -145,9 +145,6 @@
#include <THNN/generic/SpatialAveragePooling.c>
#include <TH/THGenerateFloatTypes.h>
#include <THNN/generic/SpatialDilatedMaxPooling.c>
#include <TH/THGenerateFloatTypes.h>
#include <THNN/generic/SpatialMaxUnpooling.c>
#include <TH/THGenerateFloatTypes.h>


@ -277,8 +277,6 @@ def _generate_function_classes(scope_dict):
'SpatialConvolutionMM',
'TemporalConvolution',
'SpatialAveragePooling',
'SpatialMaxPooling',
'SpatialDilatedMaxPooling',
'SpatialMaxUnpooling',
'VolumetricAveragePooling',
'VolumetricMaxPooling',