Port adaptive_avg_pool3d to ATen (#19898)
Summary: Resolves #18065.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19898
Differential Revision: D15240607
Pulled By: ezyang
fbshipit-source-id: 00cf23ed20c1757d5eef71fd8c6a2f53d372e341

Committed by Facebook Github Bot
parent 5268b7dfaf
commit 7799ea5eb3
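For orientation, the operator being ported here is exposed through the ATen C++ frontend once the dispatch entries below are in place. The following is a minimal usage sketch; the shapes and the 4x4x4 output size are arbitrary illustration values, not taken from this PR.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A 5D (batch-mode) input: batch=2, features=3, T=8, H=9, W=10.
  at::Tensor input = at::randn({2, 3, 8, 9, 10});

  // Adaptive average pooling picks its bin boundaries so the output always
  // has the requested spatial size (here 4x4x4), whatever the input size.
  at::Tensor out = at::adaptive_avg_pool3d(input, {4, 4, 4});

  std::cout << out.sizes() << std::endl;  // [2, 3, 4, 4, 4]
  return 0;
}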
312
aten/src/ATen/native/AdaptiveAveragePooling3d.cpp
Normal file
@@ -0,0 +1,312 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>

namespace at {
namespace native {

namespace {

inline int start_index(int a, int b, int c) {
  return (int)std::floor((float)(a * c) / b);
}

inline int end_index(int a, int b, int c) {
  return (int)std::ceil((float)((a + 1) * c) / b);
}

template <typename scalar_t>
static void adaptive_avg_pool3d_out_frame(
    scalar_t* input_p,
    scalar_t* output_p,
    int64_t sizeD,
    int64_t isizeT,
    int64_t isizeH,
    int64_t isizeW,
    int64_t osizeT,
    int64_t osizeH,
    int64_t osizeW,
    int64_t istrideD,
    int64_t istrideT,
    int64_t istrideH,
    int64_t istrideW) {
  int64_t d;
#pragma omp parallel for private(d)
  for (d = 0; d < sizeD; d++) {
    /* loop over output */
    int64_t ot, oh, ow;
    for (ot = 0; ot < osizeT; ot++) {
      int istartT = start_index(ot, osizeT, isizeT);
      int iendT = end_index(ot, osizeT, isizeT);
      int kT = iendT - istartT;

      for (oh = 0; oh < osizeH; oh++) {
        int istartH = start_index(oh, osizeH, isizeH);
        int iendH = end_index(oh, osizeH, isizeH);
        int kH = iendH - istartH;

        for (ow = 0; ow < osizeW; ow++) {
          int istartW = start_index(ow, osizeW, isizeW);
          int iendW = end_index(ow, osizeW, isizeW);
          int kW = iendW - istartW;

          /* local pointers */
          scalar_t* ip = input_p + d * istrideD + istartT * istrideT +
              istartH * istrideH + istartW * istrideW;
          scalar_t* op = output_p + d * osizeT * osizeH * osizeW +
              ot * osizeH * osizeW + oh * osizeW + ow;

          /* compute local average: */
          scalar_t sum = 0;
          int it, ih, iw;
          for (it = 0; it < kT; it++) {
            for (ih = 0; ih < kH; ih++) {
              for (iw = 0; iw < kW; iw++) {
                scalar_t val =
                    *(ip + it * istrideT + ih * istrideH + iw * istrideW);
                sum += val;
              }
            }
          }

          /* set output to local average */
          *op = sum / kT / kH / kW;
        }
      }
    }
  }
}

void adaptive_avg_pool3d_out_cpu_template(
    Tensor& output,
    Tensor const& input,
    IntArrayRef output_size) {
  for (int64_t i = 0; i < input.ndimension(); i++) {
    AT_CHECK(
        input.size(i) > 0,
        "adaptive_avg_pool3d(): expected input to have non-empty spatial dimensions, "
        "but input has sizes ",
        input.sizes(),
        " with dimension ",
        i,
        " being "
        "empty");
  }

  AT_CHECK(
      (input.ndimension() == 4 || input.ndimension() == 5),
      "non-empty 4D or 5D (batch mode) tensor expected for input");

  /* sizes */
  int64_t sizeD = input.size(-4);
  int64_t isizeT = input.size(-3);
  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  /* strides */
  int64_t istrideD = input.stride(-4);
  int64_t istrideT = input.stride(-3);
  int64_t istrideH = input.stride(-2);
  int64_t istrideW = input.stride(-1);
  /* output sizes */
  auto osizeT = output_size[0];
  auto osizeH = output_size[1];
  auto osizeW = output_size[2];

  if (input.ndimension() == 4) {
    output.resize_({sizeD, osizeT, osizeH, osizeW});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        input.scalar_type(), "adaptive_avg_pool3d_cpu", [&] {
          auto input_data = input.data<scalar_t>();
          auto output_data = output.data<scalar_t>();
          adaptive_avg_pool3d_out_frame<scalar_t>(
              input_data,
              output_data,
              sizeD,
              isizeT,
              isizeH,
              isizeW,
              osizeT,
              osizeH,
              osizeW,
              istrideD,
              istrideT,
              istrideH,
              istrideW);
        });
  } else {
    output.resize_({input.size(-5), sizeD, osizeT, osizeH, osizeW});
    int64_t b;
#pragma omp parallel for private(b)
    for (b = 0; b < input.size(0); b++) {
      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
          input.scalar_type(), "adaptive_avg_pool3d_cpu", [&] {
            auto input_data = input.data<scalar_t>();
            auto output_data = output.data<scalar_t>();
            adaptive_avg_pool3d_out_frame<scalar_t>(
                input_data + b * input.stride(0),
                output_data + b * sizeD * osizeT * osizeH * osizeW,
                sizeD,
                isizeT,
                isizeH,
                isizeW,
                osizeT,
                osizeH,
                osizeW,
                istrideD,
                istrideT,
                istrideH,
                istrideW);
          });
    }
  }
}

template <typename scalar_t>
static void adaptive_avg_pool3d_backward_out_frame(
    scalar_t* gradInput_p,
    scalar_t* gradOutput_p,
    int64_t sizeD,
    int64_t isizeT,
    int64_t isizeH,
    int64_t isizeW,
    int64_t osizeT,
    int64_t osizeH,
    int64_t osizeW) {
  int64_t d;
#pragma omp parallel for private(d)
  for (d = 0; d < sizeD; d++) {
    scalar_t* gradInput_p_d = gradInput_p + d * isizeT * isizeW * isizeH;
    scalar_t* gradOutput_p_d = gradOutput_p + d * osizeT * osizeW * osizeH;

    /* calculate average */
    int64_t ot, oh, ow;
    for (ot = 0; ot < osizeT; ot++) {
      int istartT = start_index(ot, osizeT, isizeT);
      int iendT = end_index(ot, osizeT, isizeT);
      int kT = iendT - istartT;

      for (oh = 0; oh < osizeH; oh++) {
        int istartH = start_index(oh, osizeH, isizeH);
        int iendH = end_index(oh, osizeH, isizeH);
        int kH = iendH - istartH;

        for (ow = 0; ow < osizeW; ow++) {
          int istartW = start_index(ow, osizeW, isizeW);
          int iendW = end_index(ow, osizeW, isizeW);
          int kW = iendW - istartW;

          scalar_t grad_delta =
              gradOutput_p_d[ot * osizeH * osizeW + oh * osizeW + ow] / kT /
              kH / kW;

          int it, ih, iw;
          for (it = istartT; it < iendT; it++) {
            for (ih = istartH; ih < iendH; ih++) {
              for (iw = istartW; iw < iendW; iw++) {
                /* update gradient */
                gradInput_p_d[it * isizeH * isizeW + ih * isizeW + iw] +=
                    grad_delta;
              }
            }
          }
        }
      }
    }
  }
}

Tensor& adaptive_avg_pool3d_backward_out_cpu_template(
    Tensor& gradInput,
    const Tensor& gradOutput_,
    const Tensor& input) {
  /* get contiguous gradOutput */
  auto gradOutput = gradOutput_.contiguous();

  /* sizes */
  int64_t sizeD = input.size(-4);
  int64_t isizeT = input.size(-3);
  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  int64_t osizeT = gradOutput.size(-3);
  int64_t osizeH = gradOutput.size(-2);
  int64_t osizeW = gradOutput.size(-1);

  /* backprop */
  if (input.ndimension() == 4) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        input.scalar_type(), "adaptive_avg_pool3d_backward_cpu", [&] {
          /* get raw pointers */
          scalar_t* gradInput_data = gradInput.data<scalar_t>();
          scalar_t* gradOutput_data = gradOutput.data<scalar_t>();

          adaptive_avg_pool3d_backward_out_frame<scalar_t>(
              gradInput_data,
              gradOutput_data,
              sizeD,
              isizeT,
              isizeH,
              isizeW,
              osizeT,
              osizeH,
              osizeW);
        });
  } else {
    int64_t b;
#pragma omp parallel for private(b)
    for (b = 0; b < input.size(0); b++) {
      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
          input.scalar_type(), "adaptive_avg_pool3d_backward_cpu", [&] {
            /* get raw pointers */
            scalar_t* gradInput_data = gradInput.data<scalar_t>();
            scalar_t* gradOutput_data = gradOutput.data<scalar_t>();
            adaptive_avg_pool3d_backward_out_frame<scalar_t>(
                gradInput_data + b * sizeD * isizeT * isizeH * isizeW,
                gradOutput_data + b * sizeD * osizeT * osizeH * osizeW,
                sizeD,
                isizeT,
                isizeH,
                isizeW,
                osizeT,
                osizeH,
                osizeW);
          });
    }
  }
  return gradInput;
}

} // namespace

Tensor& adaptive_avg_pool3d_out_cpu(
    Tensor& output,
    const Tensor& input,
    IntArrayRef output_size) {
  adaptive_avg_pool3d_out_cpu_template(output, input, output_size);
  return output;
}

Tensor adaptive_avg_pool3d_cpu(Tensor const& input, IntArrayRef output_size) {
  auto output = at::empty({0}, input.options());
  adaptive_avg_pool3d_out_cpu_template(output, input, output_size);
  return output;
}

Tensor& adaptive_avg_pool3d_backward_out_cpu(
    Tensor& gradInput,
    const Tensor& gradOutput_,
    const Tensor& input) {
  gradInput.resize_as_(input).zero_();
  adaptive_avg_pool3d_backward_out_cpu_template(gradInput, gradOutput_, input);
  return gradInput;
}

Tensor adaptive_avg_pool3d_backward_cpu(
    const Tensor& gradOutput_,
    const Tensor& input) {
  auto gradInput = at::zeros_like(input);
  adaptive_avg_pool3d_backward_out_cpu_template(gradInput, gradOutput_, input);
  return gradInput;
}

} // namespace native
} // namespace at
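The start_index()/end_index() helpers above are what make the pooling adaptive: output cell o of an axis pooled from isize to osize elements averages input positions [floor(o*isize/osize), ceil((o+1)*isize/osize)), so every input element is covered even when osize does not divide isize. A standalone sketch of that arithmetic with made-up example sizes (same formulas as in the file above, illustrative only):

#include <cmath>
#include <cstdio>

// Same formulas as start_index()/end_index() in AdaptiveAveragePooling3d.cpp.
static int start_index(int o, int osize, int isize) {
  return (int)std::floor((float)(o * isize) / osize);
}
static int end_index(int o, int osize, int isize) {
  return (int)std::ceil((float)((o + 1) * isize) / osize);
}

int main() {
  // Pool an axis of length 10 down to 4 adaptive bins.
  const int isize = 10, osize = 4;
  for (int o = 0; o < osize; ++o) {
    // Prints bins [0,3) [2,5) [5,8) [7,10): they may overlap when isize is
    // not a multiple of osize, but together they cover every input index.
    std::printf("bin %d -> [%d, %d)\n", o, start_index(o, osize, isize),
                end_index(o, osize, isize));
  }
  return 0;
}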
@@ -332,22 +332,6 @@ Tensor softshrink_backward(const Tensor & grad_output, const Tensor & self, Scal
  return at::legacy::th::_thnn_softshrink_backward(grad_output, self, lambd);
}

Tensor & adaptive_avg_pool3d_out(Tensor & output, const Tensor & self, IntArrayRef output_size) {
  return at::legacy::th::_thnn_adaptive_avg_pool3d_forward_out(output, self, output_size);
}

Tensor adaptive_avg_pool3d(const Tensor & self, IntArrayRef output_size) {
  return at::legacy::th::_thnn_adaptive_avg_pool3d_forward(self, output_size);
}

Tensor & adaptive_avg_pool3d_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self) {
  return at::legacy::th::_thnn_adaptive_avg_pool3d_backward_out(grad_input, grad_output, self);
}

Tensor adaptive_avg_pool3d_backward(const Tensor & grad_output, const Tensor & self) {
  return at::legacy::th::_thnn_adaptive_avg_pool3d_backward(grad_output, self);
}

Tensor & avg_pool2d_out(Tensor & output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad) {
  return at::legacy::th::_thnn_avg_pool2d_forward_out(output, self, kernel_size, stride, padding, ceil_mode, count_include_pad);
}
517
aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu
Normal file
@@ -0,0 +1,517 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCGeneral.h>
#include <THC/THCNumerics.cuh>
#include <THC/THCAtomics.cuh> // for atomicAdd
#include <c10/util/Exception.h>

#include <algorithm>
#include <cfloat>
#include <cmath>

namespace at {
namespace native {

namespace {

__device__ inline int start_index(int a, int b, int c) {
  return (int)std::floor((float)(a * c) / b);
}

__device__ inline int end_index(int a, int b, int c) {
  return (int)std::ceil((float)((a + 1) * c) / b);
}

// 5d tensor B x D x T x H x W
// All kernels view batch dim B and dim D as collapsed.

/*
 * Description:
 *    this function adaptively average pools an input 5D tensor along dimensions
 *    2, 3, and 4 5D input, 5D output
 *
 *    gridDim.y blocks work together on a single 2D output plane specified by
 *    (blockIdx.x + offsetZ).
 */
template <typename scalar_t>
__global__ void adaptiveaveragepool(
    scalar_t *input, scalar_t *output,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW,
    int64_t istrideD,
    int64_t istrideT, int64_t istrideH, int64_t istrideW,
    int64_t offsetZ) {
  // iterates on output pixels
  int ot, oh, ow;

  // compute offsets based on thread/block ID
  int ostartH = blockIdx.y * blockDim.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = gridDim.y * blockDim.y;
  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  // select output plane
  int64_t o_plane = blockIdx.x + offsetZ;
  ot = o_plane % osizeT; // output frame/time
  int d = o_plane / osizeT; // slice/feature

  // input frame/time range is fixed.
  int istartT = start_index(ot, osizeT, isizeT);
  int iendT = end_index(ot, osizeT, isizeT);
  int kT = iendT - istartT;

  // input offset by slice/feature and earliest relevant frame/time
  scalar_t *input_dt = input + d*istrideD + istartT*istrideT;
  // output offset by slice/feature and frame/time
  scalar_t *output_dt = output + o_plane*osizeH*osizeW;

  // For all output pixels...
  for (oh = ostartH; oh < oendH; oh += ostepH) {
    int istartH = start_index(oh, osizeH, isizeH);
    int iendH = end_index(oh, osizeH, isizeH);
    int kH = iendH - istartH;

    for (ow = ostartW; ow < oendW; ow += ostepW) {
      int istartW = start_index(ow, osizeW, isizeW);
      int iendW = end_index(ow, osizeW, isizeW);
      int kW = iendW - istartW;

      // Compute the average pooling from corresponding input pixels
      scalar_t *ptr_input = input_dt + istartH*istrideH + istartW*istrideW;
      scalar_t *ptr_output = output_dt + oh*osizeW + ow;
      scalar_t sum = ScalarConvert<int, scalar_t>::to(0);

      int it, ih, iw;
      for (it = 0; it < kT; ++it) {
        for (ih = 0; ih < kH; ++ih) {
          for (iw = 0; iw < kW; ++iw) {
            scalar_t val = ptr_input[ih*istrideH + iw*istrideW];
            sum += val;
          }
        }
        ptr_input += istrideT; // next input frame
      }
      // Update output
      *ptr_output = sum / kT / kH / kW;
    }
  }
}

template <typename scalar_t>
void adaptiveaveragepool_loop(
    scalar_t *input_data, scalar_t *output_data,
    int64_t totalZ,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW,
    int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) {
  int64_t offsetZ = 0;
  dim3 threads(32, 8);
  // each H*W plane is processed by blocksH thread blocks
  int blocksH = std::max((int)(16L / totalZ), 1);
  while (totalZ > 0) {
    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
    adaptiveaveragepool<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        input_data, output_data,
        isizeT, isizeH, isizeW,
        osizeT, osizeH, osizeW,
        istrideD,
        istrideT, istrideH, istrideW,
        offsetZ);

    totalZ -= 65535;
    offsetZ += 65535;
    AT_CUDA_CHECK(cudaGetLastError());
  }
}

/*
 * Description:
 *    This function computes the gradInput from gradOutput.
 *
 *    gridDim.y blocks work together on a single 2D output plane specified by
 *    (blockIdx.x + offsetZ).
 */
template <typename scalar_t>
__global__ void adaptiveaveragegradinput(
    scalar_t *gradInput, scalar_t *gradOutput,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW,
    int64_t offsetZ)
{
  // iterators on input pixels
  int it, ih, iw;

  // compute offsets based on thread/block ID
  int istartH = blockIdx.y * blockDim.y + threadIdx.y;
  int iendH = isizeH;
  int istepH = gridDim.y * blockDim.y;
  int istartW = threadIdx.x;
  int iendW = isizeW;
  int istepW = blockDim.x;

  // select input plane
  int64_t i_plane = blockIdx.x + offsetZ;
  it = i_plane % isizeT; // output frame/time
  int d = i_plane / isizeT; // slice/feature

  // output frame/time range is fixed.
  int ostartT = start_index(it, isizeT, osizeT);
  int oendT = end_index(it, isizeT, osizeT);

  // gradInput offset by slice/feature and frame/time.
  scalar_t *gradInput_dt = gradInput + i_plane*isizeH*isizeW;
  // gradOutput offset by slice/feature and earliest relevant frame/time
  scalar_t *gradOutput_dt = gradOutput + (d*osizeT + ostartT)*osizeH*osizeW;

  // For all input pixels...
  for (ih = istartH; ih < iendH; ih += istepH) {
    int ostartH = start_index(ih, isizeH, osizeH);
    int oendH = end_index(ih, isizeH, osizeH);

    for (iw = istartW; iw < iendW; iw += istepW) {
      int ostartW = start_index(iw, isizeW, osizeW);
      int oendW = end_index(iw, isizeW, osizeW);

      // Compute the gradients from corresponding output pixels
      scalar_t *ptr_gradInput = gradInput_dt + ih*isizeW + iw;
      scalar_t *ptr_gradOutput = gradOutput_dt;

      // for all relevant output pixels
      int ot, oh, ow;
      for (ot = ostartT; ot < oendT; ++ot) {
        int kT = end_index(ot, osizeT, isizeT) - start_index(ot, osizeT, isizeT);
        for (oh = ostartH; oh < oendH; ++oh) {
          int kH = end_index(oh, osizeH, isizeH) - start_index(oh, osizeH, isizeH);
          for (ow = ostartW; ow < oendW; ++ow) {
            int kW = end_index(ow, osizeW, isizeW) - start_index(ow, osizeW, isizeW);
            scalar_t grad_delta = ptr_gradOutput[oh*osizeW + ow] / kW / kH / kT;
            *ptr_gradInput += grad_delta;
          }
        }
        ptr_gradOutput += osizeH*osizeW; // next output frame
      }
    }
  }
}

template <typename scalar_t>
void adaptiveaveragegradinput_loop(
    scalar_t *gradInput_data, scalar_t *gradOutput_data,
    int64_t totalZ,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW) {
  int64_t offsetZ = 0;
  dim3 threads(32, 8);
  // each H*W plane is processed by blocksH thread blocks
  int blocksH = std::max((int)(16L / totalZ), 1);
  while (totalZ > 0) {
    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
    adaptiveaveragegradinput<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        gradInput_data, gradOutput_data,
        isizeT, isizeH, isizeW,
        osizeT, osizeH, osizeW,
        offsetZ);

    totalZ -= 65535;
    offsetZ += 65535;
    AT_CUDA_CHECK(cudaGetLastError());
  }
}

/*
 * Description:
 *    This function computes the gradInput from gradOutput.
 *
 *    gridDim.y blocks work together on a single 2D output plane specified by
 *    (blockIdx.x + offsetZ).
 *
 *    (uses atomic add)
 *
 */
template <typename scalar_t>
__global__ void atomicadaptiveaveragegradinput(
    scalar_t *gradInput, scalar_t *gradOutput,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW,
    int64_t offsetZ)
{
  // iterators on output pixels
  int ot, oh, ow;

  // compute offsets based on thread/block ID
  int ostartH = blockIdx.y * blockDim.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = gridDim.y * blockDim.y;
  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  // select output plane
  int64_t o_plane = blockIdx.x + offsetZ;
  ot = o_plane % osizeT; // output frame/time
  int d = o_plane / osizeT; // output slice/feature

  // input frame/time range is fixed.
  int istartT = start_index(ot, osizeT, isizeT);
  int iendT = end_index(ot, osizeT, isizeT);
  int kT = iendT - istartT;

  // gradInput offset by slice/feature and earliest relevant frame/time
  scalar_t *gradInput_nt = gradInput + (d*isizeT + istartT)*isizeH*isizeW;
  // gradOutput offset by slice/feature and frame/time
  scalar_t *gradOutput_nt = gradOutput + o_plane*osizeH*osizeW;

  // For all output pixels...
  for (oh = ostartH; oh < oendH; oh += ostepH) {
    int istartH = start_index(oh, osizeH, isizeH);
    int iendH = end_index(oh, osizeH, isizeH);
    int kH = iendH - istartH;

    for (ow = ostartW; ow < oendW; ow += ostepW) {
      int istartW = start_index(ow, osizeW, isizeW);
      int iendW = end_index(ow, osizeW, isizeW);
      int kW = iendW - istartW;

      // Compute the gradients from corresponding input pixels
      scalar_t *ptr_gradInput = gradInput_nt + istartH*isizeW + istartW;
      scalar_t *ptr_gradOutput = gradOutput_nt + oh*osizeW + ow;
      scalar_t grad_delta = *ptr_gradOutput / kT / kH / kW;

      int it, ih, iw;
      for (it = 0; it < kT; ++it) {
        for (ih = 0; ih < kH; ++ih) {
          for (iw = 0; iw < kW; ++iw) {
            atomicAdd(&(ptr_gradInput[ih*isizeW + iw]), grad_delta);
          }
        }
        ptr_gradInput += isizeH*isizeW; // next input frame
      }
    }
  }
}

template <typename scalar_t>
void atomicadaptiveaveragegradinput_loop(
    scalar_t* gradInput_data, scalar_t* gradOutput_data,
    int64_t totalZ,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW) {
  int64_t offsetZ = 0;
  dim3 threads(32, 8);
  int blocksH = std::max((int)(16L / totalZ), 1);
  while (totalZ > 0) {
    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
    atomicadaptiveaveragegradinput<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        gradInput_data, gradOutput_data,
        isizeT, isizeH, isizeW,
        osizeT, osizeH, osizeW,
        offsetZ);

    totalZ -= 65535;
    offsetZ += 65535;
    AT_CUDA_CHECK(cudaGetLastError());
  }
}

// 5D tensor B x D x T x H x w

void adaptive_avg_pool3d_out_cuda_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef& output_size) {
  TensorArg output_arg{output, "output", 1};
  TensorArg input_arg{input_, "input_", 2};

  checkAllSameGPU("adaptive_avg_pool3d_cuda", {output_arg, input_arg});

  for (int64_t i = 0; i < input_.ndimension(); i++) {
    AT_CHECK(
        input_.size(i) > 0,
        "adaptive_avg_pool3d_cuda(): expected input to have non-empty spatial dimensions, "
        "but input has sizes ", input_.sizes(),
        " with dimension ", i, " being empty");
  }

  AT_CHECK(
      (input_.ndimension() == 4 || input_.ndimension() == 5),
      "non-empty 4D or 5D (batch mode) tensor expected for input");

  // the jit sometimes passes output_size.size() == 1
  AT_CHECK(
      output_size.size() == 1 || output_size.size() == 3,
      "adaptive_avg_pool3d: internal error: output_size.size() must be 1 or 3");

  int64_t osizeT = output_size[0];
  int64_t osizeH = output_size[1];
  int64_t osizeW = output_size[2];

  int64_t sizeD, isizeT, isizeH, isizeW;
  int64_t istrideD, istrideT, istrideH, istrideW;
  int64_t totalZ;

  const Tensor& input = input_.ndimension() == 4 ? input_ : input_.contiguous();

  if (input.ndimension() == 4) {
    sizeD = input.size(0);
    isizeT = input.size(1);
    isizeH = input.size(2);
    isizeW = input.size(3);

    istrideD = input.stride(0);
    istrideT = input.stride(1);
    istrideH = input.stride(2);
    istrideW = input.stride(3);

    output.resize_({sizeD, osizeT, osizeH, osizeW});

    totalZ = sizeD * osizeT;
  } else {
    int64_t sizeB = input.size(0);
    sizeD = input.size(1);
    isizeT = input.size(2);
    isizeH = input.size(3);
    isizeW = input.size(4);

    istrideD = input.stride(1);
    istrideT = input.stride(2);
    istrideH = input.stride(3);
    istrideW = input.stride(4);

    output.resize_({sizeB, sizeD, osizeT, osizeH, osizeW});

    totalZ = sizeB * sizeD * osizeT;
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "adaptive_avg_pool3d_cuda", [&] {
        scalar_t* input_data = input.data<scalar_t>();
        scalar_t* output_data = output.data<scalar_t>();

        adaptiveaveragepool_loop(
            input_data, output_data,
            totalZ,
            isizeT, isizeH, isizeW,
            osizeT, osizeH, osizeW,
            istrideD, istrideT, istrideH, istrideW);
      });
}

void adaptive_avg_pool3d_backward_out_cuda_template(
    Tensor& gradInput,
    const Tensor& gradOutput_,
    const Tensor& input) {
  TensorArg grad_input_arg{gradInput, "gradInput", 1};
  TensorArg grad_output_arg{gradOutput_, "gradOutput_", 2};
  TensorArg input_arg{input, "input", 3};

  checkAllSameGPU(
      "adaptive_avg_pool3d_out_cuda",
      {grad_input_arg, grad_output_arg, input_arg});

  const Tensor gradOutput = gradOutput_.contiguous();

  gradInput.resize_as_(input);
  gradInput.zero_();

  int64_t sizeD, isizeT, isizeH, isizeW;
  int64_t osizeT, osizeH, osizeW;
  int64_t totalZ;

  if (input.ndimension() == 4) {
    sizeD = input.size(0);
    isizeT = input.size(1);
    isizeH = input.size(2);
    isizeW = input.size(3);

    osizeT = gradOutput.size(1);
    osizeH = gradOutput.size(2);
    osizeW = gradOutput.size(3);
  } else {
    sizeD = input.size(1);
    isizeT = input.size(2);
    isizeH = input.size(3);
    isizeW = input.size(4);

    osizeT = gradOutput.size(2);
    osizeH = gradOutput.size(3);
    osizeW = gradOutput.size(4);
  }

  bool atomic = (isizeW%osizeW != 0) || (isizeH%osizeH != 0) || (isizeT%osizeT != 0);

  if (input.ndimension() == 4) {
    totalZ = atomic ? sizeD * osizeT : sizeD * isizeT;
  } else {
    int sizeB = input.size(0);
    totalZ = atomic ? sizeB * sizeD * osizeT : sizeB * sizeD * isizeT;
  }

  if (atomic) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        input.scalar_type(), "adaptive_avg_pool3d_backward_cuda", [&] {
          scalar_t* gradInput_data = gradInput.data<scalar_t>();
          scalar_t* gradOutput_data = gradOutput.data<scalar_t>();

          atomicadaptiveaveragegradinput_loop(
              gradInput_data, gradOutput_data,
              totalZ,
              isizeT, isizeH, isizeW,
              osizeT, osizeH, osizeW);
        });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        input.scalar_type(), "adaptive_avg_pool3d_backward_cuda", [&] {
          scalar_t* gradInput_data = gradInput.data<scalar_t>();
          scalar_t* gradOutput_data = gradOutput.data<scalar_t>();

          adaptiveaveragegradinput_loop(
              gradInput_data, gradOutput_data,
              totalZ,
              isizeT, isizeH, isizeW,
              osizeT, osizeH, osizeW);
        });
  }
}

} // namespace

Tensor& adaptive_avg_pool3d_out_cuda(
    Tensor& output,
    const Tensor& input,
    IntArrayRef output_size) {
  adaptive_avg_pool3d_out_cuda_template(output, input, output_size);
  return output;
}

Tensor adaptive_avg_pool3d_cuda(
    const Tensor& input,
    IntArrayRef output_size) {
  auto output = at::empty({0}, input.options());
  adaptive_avg_pool3d_out_cuda_template(output, input, output_size);
  return output;
}

Tensor& adaptive_avg_pool3d_backward_out_cuda(
    Tensor& gradInput,
    const Tensor& gradOutput_,
    const Tensor& input) {
  adaptive_avg_pool3d_backward_out_cuda_template(gradInput, gradOutput_, input);
  return gradInput;
}

Tensor adaptive_avg_pool3d_backward_cuda(
    const Tensor& gradOutput_,
    const Tensor& input) {
  auto gradInput = at::zeros_like(input);
  adaptive_avg_pool3d_backward_out_cuda_template(gradInput, gradOutput_, input);
  return gradInput;
}

} // namespace native
} // namespace at
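In the CUDA backward above, the plain adaptiveaveragegradinput kernel is only used when the adaptive bins do not overlap; otherwise different output cells contribute to the same gradInput element from different thread blocks, and the atomicAdd variant is selected. A small sketch of that divisibility check (it mirrors the `atomic` flag computed in adaptive_avg_pool3d_backward_out_cuda_template; the example sizes are made up):

#include <cstdio>

// When every input extent is a multiple of its output extent, the adaptive
// bins tile the input exactly and the non-atomic gradient kernel is safe.
static bool needs_atomic(int isizeT, int isizeH, int isizeW,
                         int osizeT, int osizeH, int osizeW) {
  return (isizeW % osizeW != 0) || (isizeH % osizeH != 0) ||
         (isizeT % osizeT != 0);
}

int main() {
  std::printf("%d\n", needs_atomic(8, 8, 8, 4, 4, 4));   // 0: bins tile exactly
  std::printf("%d\n", needs_atomic(10, 9, 8, 4, 4, 4));  // 1: bins overlap
  return 0;
}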
@@ -3665,15 +3665,27 @@

- func: adaptive_avg_pool3d(Tensor self, int[3] output_size, *, Tensor(a!) out) -> Tensor(a!)
  python_module: nn
  dispatch:
    CPU: adaptive_avg_pool3d_out_cpu
    CUDA: adaptive_avg_pool3d_out_cuda

- func: adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor
  python_module: nn
  dispatch:
    CPU: adaptive_avg_pool3d_cpu
    CUDA: adaptive_avg_pool3d_cuda

- func: adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
  python_module: nn
  dispatch:
    CPU: adaptive_avg_pool3d_backward_out_cpu
    CUDA: adaptive_avg_pool3d_backward_out_cuda

- func: adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor
  python_module: nn
  dispatch:
    CPU: adaptive_avg_pool3d_backward_cpu
    CUDA: adaptive_avg_pool3d_backward_cuda

# Return: (Tensor output, Tensor indices)
- func: adaptive_max_pool2d(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
@@ -118,12 +118,6 @@

# Pooling

- name: _thnn_adaptive_avg_pool3d(Tensor self, IntArrayRef[3] output_size)
  cname: VolumetricAdaptiveAveragePooling
  scalar_check:
    output: 'false'
    grad_input: 'false'

- name: _thnn_avg_pool2d(Tensor self, IntArrayRef[2] kernel_size, IntArrayRef[2] stride={}, IntArrayRef[2] padding=0, bool ceil_mode=false, bool count_include_pad=true)
  cname: SpatialAveragePooling
  default_init:
@@ -50,7 +50,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/TemporalMaxPooling.cu
${CMAKE_CURRENT_SOURCE_DIR}/TemporalRowConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/TemporalUpSamplingLinear.cu
${CMAKE_CURRENT_SOURCE_DIR}/TemporalUpSamplingNearest.cu
${CMAKE_CURRENT_SOURCE_DIR}/VolumetricAdaptiveAveragePooling.cu
${CMAKE_CURRENT_SOURCE_DIR}/VolumetricAveragePooling.cu
${CMAKE_CURRENT_SOURCE_DIR}/VolumetricConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/VolumetricDilatedConvolution.cu
@@ -1,248 +0,0 @@
#include <THCUNN/THCUNN.h>
#include <THC/THCTensor.hpp>
#include <TH/THHalf.h>
#include <THCUNN/THCHalfAutoNumerics.cuh>
#include <THC/THCAtomics.cuh>

#define START_IND(a,b,c) (int)floor((float)(a * c) / b)
#define END_IND(a,b,c) (int)ceil((float)((a + 1) * c) / b)
// #define START_IND(a,b,c) a * c / b
// #define END_IND(a,b,c) (a + 1) * c / b + ((a + 1) * c % b > 0)?1:0


#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit

// 5d tensor B x D x T x H x W
// All kernels view batch dim B and feature dim D as collapsed.

/*
 * Description:
 *    This function adaptively average pools an input 5D tensor along dimensions
 *    2, 3 and 4.
 *
 *    gridDim.y blocks work together on a single 2D output plane specified by
 *    (blockIdx.x + offsetZ).
 */
template <typename T>
__global__ void cunn_VolumetricAdaptiveAveragePooling_updateOutput_kernel(
  T *input, T *output,
  int isizeT, int isizeH, int isizeW,
  int osizeT, int osizeH, int osizeW,
  int64_t istrideD,
  int64_t istrideT, int64_t istrideH, int64_t istrideW,
  int64_t offsetZ)
{
  // iterators on output pixels
  int ot, oh, ow;

  // compute offsets based on thread/block ID
  int ostartH = blockIdx.y * blockDim.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = gridDim.y * blockDim.y;
  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  // select output plane
  int64_t o_plane = blockIdx.x + offsetZ;
  ot = o_plane % osizeT; // output frame/time
  int d = o_plane / osizeT; // slice/feature

  // input frame/time ramge is fixed.
  int istartT = START_IND(ot, osizeT, isizeT);
  int iendT = END_IND(ot, osizeT, isizeT);
  int kT = iendT - istartT;

  // input offset by slice/feature and earliest relevant frame/time
  T *input_dt = input + d*istrideD + istartT*istrideT;
  // output offset by slice/feature and frame/time
  T *output_dt = output + o_plane*osizeH*osizeW;

  // For all output pixels...
  for(oh = ostartH; oh < oendH; oh += ostepH) {

    int istartH = START_IND(oh, osizeH, isizeH);
    int iendH = END_IND(oh, osizeH, isizeH);
    int kH = iendH - istartH;

    for(ow = ostartW; ow < oendW; ow += ostepW) {

      int istartW = START_IND(ow, osizeW, isizeW);
      int iendW = END_IND(ow, osizeW, isizeW);
      int kW = iendW - istartW;

      // Compute the average pooling from corresponding input pixels
      T *ptr_input = input_dt + istartH*istrideH + istartW*istrideW;
      T *ptr_output = output_dt + oh*osizeW + ow;
      T sum = ScalarConvert<int, T>::to(0);

      int it, ih, iw;
      for(it = 0; it < kT; ++it) {
        for(ih = 0; ih < kH; ++ih) {
          for(iw = 0; iw < kW; ++iw) {
            T val = ptr_input[ih*istrideH + iw*istrideW];
            sum += val;
          }
        }
        ptr_input += istrideT; // next input frame
      }
      // Update output
      *ptr_output = sum / kT / kH / kW;
    }
  }
}

/*
 * Description:
 *    This function computes the gradInput from gradOutput.
 *
 *    gridDim.y blocks work together on a single 2D input plane specified by
 *    (blockIdx.x + offsetZ).
 */
template <typename T>
__global__ void cunn_VolumetricAdaptiveAveragePooling_updateGradInput_kernel(
  T *gradInput, T *gradOutput,
  int isizeT, int isizeH, int isizeW,
  int osizeT, int osizeH, int osizeW,
  int64_t offsetZ
)
{
  // iterators on input pixels
  int it, ih, iw;

  // compute offsets based on thread/block ID
  int istartH = blockIdx.y * blockDim.y + threadIdx.y;
  int iendH = isizeH;
  int istepH = gridDim.y * blockDim.y;
  int istartW = threadIdx.x;
  int iendW = isizeW;
  int istepW = blockDim.x;

  // select input plane
  int64_t i_plane = blockIdx.x + offsetZ;
  it = i_plane % isizeT; // output frame/time
  int d = i_plane / isizeT; // slice/feature

  // output frame/time ramge is fixed.
  int ostartT = START_IND(it, isizeT, osizeT);
  int oendT = END_IND(it, isizeT, osizeT);

  // gradInput offset by slice/feature and frame/time
  T *gradInput_dt = gradInput + i_plane*isizeH*isizeW;
  // gradOutput offset by slice/feature and earliest relevant frame/time
  T *gradOutput_dt = gradOutput + (d*osizeT + ostartT)*osizeH*osizeW;

  // For all input pixels...
  for(ih = istartH; ih < iendH; ih += istepH) {

    int ostartH = START_IND(ih, isizeH, osizeH);
    int oendH = END_IND(ih, isizeH, osizeH);

    for(iw = istartW; iw < iendW; iw += istepW) {

      int ostartW = START_IND(iw, isizeW, osizeW);
      int oendW = END_IND(iw, isizeW, osizeW);

      // Compute the gradients from corresponding output pixels
      T *ptr_gradInput = gradInput_dt + ih*isizeW + iw;
      T *ptr_gradOutput = gradOutput_dt;

      // for all relevant output pixels
      int ot, oh, ow;
      for(ot = ostartT; ot < oendT; ++ot) {
        int kT = END_IND(ot, osizeT, isizeT) - START_IND(ot, osizeT, isizeT);
        for(oh = ostartH; oh < oendH; ++oh) {
          int kH = END_IND(oh, osizeH, isizeH) - START_IND(oh, osizeH, isizeH);
          for(ow = ostartW; ow < oendW; ++ow) {
            int kW = END_IND(ow, osizeW, isizeW) - START_IND(ow, osizeW, isizeW);
            T grad_delta = ptr_gradOutput[oh*osizeW + ow] / kW / kH / kT;
            *ptr_gradInput += grad_delta;
          }
        }
        ptr_gradOutput += osizeH*osizeW; // next output frame
      }
    }
  }
}

/*
 * Description:
 *    This function computes the gradInput from gradOutput without assuming
 *    dependencies between input pixels and output pixels.
 *
 *    gridDim.y blocks work together on a single 2D output plane specified by
 *    (blockIdx.x + offsetZ).
 *
 *    (uses atomic add)
 */
template <typename T>
__global__ void cunn_atomic_VolumetricAdaptiveAveragePooling_updateGradInput_kernel(
  T *gradInput, T *gradOutput,
  int isizeT, int isizeH, int isizeW,
  int osizeT, int osizeH, int osizeW,
  int64_t offsetZ
)
{
  // iterators on output pixels
  int ot, oh, ow;

  // compute offsets based on thread/block ID
  int ostartH = blockIdx.y * blockDim.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = gridDim.y * blockDim.y;
  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  // select output plane
  int64_t o_plane = blockIdx.x + offsetZ;
  ot = o_plane % osizeT; // output frame/time
  int d = o_plane / osizeT; // output slice/feature

  // input frame/time ramge is fixed.
  int istartT = START_IND(ot, osizeT, isizeT);
  int iendT = END_IND(ot, osizeT, isizeT);
  int kT = iendT - istartT;

  // gradInput offset by slice/feature and earliest relevant frame/time
  T *gradInput_nt = gradInput + (d*isizeT + istartT)*isizeH*isizeW;
  // gradOutput offset by slice/feature and frame/time
  T *gradOutput_nt = gradOutput + o_plane*osizeH*osizeW;

  // For all output pixels...
  for(oh = ostartH; oh < oendH; oh += ostepH) {

    int istartH = START_IND(oh, osizeH, isizeH);
    int iendH = END_IND(oh, osizeH, isizeH);
    int kH = iendH - istartH;

    for(ow = ostartW; ow < oendW; ow += ostepW) {

      int istartW = START_IND(ow, osizeW, isizeW);
      int iendW = END_IND(ow, osizeW, isizeW);
      int kW = iendW - istartW;

      // Compute the gradients from corresponding input pixels
      T *ptr_gradInput = gradInput_nt + istartH*isizeW + istartW;
      T *ptr_gradOutput = gradOutput_nt + oh*osizeW + ow;
      T grad_delta = *ptr_gradOutput / kT / kH / kW;

      int it, ih, iw;
      for(it = 0; it < kT; ++it) {
        for(ih = 0; ih < kH; ++ih) {
          for(iw = 0; iw < kW; ++iw) {
            atomicAdd(&(ptr_gradInput[ih*isizeW + iw]), grad_delta);
          }
        }
        ptr_gradInput += isizeH*isizeW; // next input frame
      }
    }
  }
}

#include <THCUNN/generic/VolumetricAdaptiveAveragePooling.cu>
#include <THC/THCGenerateFloatTypes.h>

#undef CUDA_MAX_THREADS
#undef START_IND
#undef END_IND
@@ -1315,20 +1315,6 @@ THC_API void THNN_(VolumetricMaxUnpooling_updateGradInput)(
                  int dT, int dW, int dH,
                  int padT, int padW, int padH);

THC_API void THNN_(VolumetricAdaptiveAveragePooling_updateOutput)(
                  THCState *state,
                  THCTensor *input,
                  THCTensor *output,
                  int osizeT,
                  int osizeW,
                  int osizeH);

THC_API void THNN_(VolumetricAdaptiveAveragePooling_updateGradInput)(
                  THCState *state,
                  THCTensor *input,
                  THCTensor *gradOutput,
                  THCTensor *gradInput);

THC_API void THNN_(VolumetricUpSamplingNearest_updateGradInput)(
                  THCState *state,
                  THCTensor *gradOutput,
@@ -1,173 +0,0 @@
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "THCUNN/generic/VolumetricAdaptiveAveragePooling.cu"
#else

#include <THCUNN/common.h>

// 5d tensor B x D x T x H x W

void THNN_(VolumetricAdaptiveAveragePooling_updateOutput)(
           THCState *state,
           THCTensor *input,
           THCTensor *output,
           int osizeT,
           int osizeW,
           int osizeH)
{
  THCUNN_assertSameGPU(state, 2, input, output);

  THCUNN_argCheck(state, !input->is_empty() && (input->dim() == 4 || input->dim() == 5), 2, input,
                  "non-empty 4D or 5D (batch mode) tensor expected for input, but got: %s");


  scalar_t *output_data;
  scalar_t *input_data;

  int64_t sizeD, isizeT, isizeH, isizeW;
  int64_t istrideD, istrideT, istrideH, istrideW;
  int64_t totalZ;

  if (input->dim() == 4) {
    sizeD = input->size(0);
    isizeT = input->size(1);
    isizeH = input->size(2);
    isizeW = input->size(3);

    istrideD = input->stride(0);
    istrideT = input->stride(1);
    istrideH = input->stride(2);
    istrideW = input->stride(3);

    THCTensor_(resize4d)(state, output, sizeD, osizeT, osizeH, osizeW);

    totalZ = sizeD * osizeT;
  } else {
    input = THCTensor_(newContiguous)(state, input);

    int64_t sizeB = input->size(0);
    sizeD = input->size(1);
    isizeT = input->size(2);
    isizeH = input->size(3);
    isizeW = input->size(4);

    istrideD = input->stride(1);
    istrideT = input->stride(2);
    istrideH = input->stride(3);
    istrideW = input->stride(4);

    THCTensor_(resize5d)(state, output, sizeB, sizeD, osizeT, osizeH, osizeW);

    totalZ = sizeB * sizeD * osizeT;
  }

  input_data = THCTensor_(data)(state, input);
  output_data = THCTensor_(data)(state, output);

  int64_t offsetZ = 0;
  dim3 threads(32, 8);
  // each H*W plane is processed by blocksH thread blocks
  int blocksH = max((int)(16L / totalZ), 1);
  while (totalZ > 0) {
    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
    cunn_VolumetricAdaptiveAveragePooling_updateOutput_kernel
      <<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
        input_data, output_data, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW,
        istrideD, istrideT, istrideH, istrideW, offsetZ
      );

    totalZ -= 65535;
    offsetZ += 65535;
    THCudaCheck(cudaGetLastError());
  }

  if (input->dim() == 5) {
    // clean
    THCTensor_(free)(state, input);
  }
}

void THNN_(VolumetricAdaptiveAveragePooling_updateGradInput)(
           THCState *state,
           THCTensor *input,
           THCTensor *gradOutput,
           THCTensor *gradInput)
{
  THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);

  gradOutput = THCTensor_(newContiguous)(state, gradOutput);

  THCTensor_(resizeAs)(state, gradInput, input);
  THCTensor_(zero)(state, gradInput);

  scalar_t *gradInput_data;
  scalar_t *gradOutput_data;

  int64_t sizeD, isizeT, isizeH, isizeW;
  int64_t osizeT, osizeH, osizeW;
  int64_t totalZ;

  if (input->dim() == 4) {
    sizeD = input->size(0);
    isizeT = input->size(1);
    isizeH = input->size(2);
    isizeW = input->size(3);

    osizeT = gradOutput->size(1);
    osizeH = gradOutput->size(2);
    osizeW = gradOutput->size(3);
  } else {
    sizeD = input->size(1);
    isizeT = input->size(2);
    isizeH = input->size(3);
    isizeW = input->size(4);

    osizeT = gradOutput->size(2);
    osizeH = gradOutput->size(3);
    osizeW = gradOutput->size(4);
  }

  // somehow nonatomic is passing all test for volumetric case.
  bool atomic = false; //(isizeW%osizeW != 0) || (isizeH%osizeH != 0) || (isizeT%osizeT != 0);

  if (input->dim() == 4) {
    totalZ = atomic ? sizeD * osizeT : sizeD * isizeT;
  } else {
    int sizeB = input->size(0);
    totalZ = atomic ? sizeB * sizeD * osizeT : sizeB * sizeD * isizeT;
  }

  gradInput_data = THCTensor_(data)(state, gradInput);
  gradOutput_data = THCTensor_(data)(state, gradOutput);

  int64_t offsetZ = 0;
  dim3 threads(32, 8);
  // each H*W plane is processed by blocksH thread blocks
  int blocksH = max((int)(16L / totalZ), 1);
  while (totalZ > 0) {
    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);

    if (atomic)
    {
      cunn_atomic_VolumetricAdaptiveAveragePooling_updateGradInput_kernel
        <<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
          gradInput_data, gradOutput_data, isizeT, isizeH, isizeW,
          osizeT, osizeH, osizeW, offsetZ
        );
    } else {
      cunn_VolumetricAdaptiveAveragePooling_updateGradInput_kernel
        <<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
          gradInput_data, gradOutput_data, isizeT, isizeH, isizeW,
          osizeT, osizeH, osizeW, offsetZ
        );
    }

    totalZ -= 65535;
    offsetZ += 65535;
    THCudaCheck(cudaGetLastError());
  }
  // clean
  THCTensor_(free)(state, gradOutput);

}

#endif
@@ -808,19 +808,6 @@ TH_API void THNN_(VolumetricMaxUnpooling_updateGradInput)(
                          int dT, int dW, int dH,
                          int pT, int pW, int pH);

TH_API void THNN_(VolumetricAdaptiveAveragePooling_updateOutput)(
                          THNNState *state,
                          THTensor *input,
                          THTensor *output,
                          int osizeT,
                          int osizeW,
                          int osizeH);
TH_API void THNN_(VolumetricAdaptiveAveragePooling_updateGradInput)(
                          THNNState *state,
                          THTensor *input,
                          THTensor *gradOutput,
                          THTensor *gradInput);

TH_API void THNN_(FeatureLPPooling_updateOutput)(
                          THNNState *state,
                          THTensor *input,
@@ -1,305 +0,0 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "THNN/generic/VolumetricAdaptiveAveragePooling.c"
#else

#include <ATen/Parallel.h>

#define START_IND(a,b,c) (int)floor((float)(a * c) / b)
#define END_IND(a,b,c) (int)ceil((float)((a + 1) * c) / b)
// #define START_IND(a,b,c) a * c / b
// #define END_IND(a,b,c) (a + 1) * c / b + ((a + 1) * c % b > 0)?1:0

// 5d tensor B x D x T x H x W

static void THNN_(VolumetricAdaptiveAveragePooling_updateOutput_frame)(
          scalar_t *input_p,
          scalar_t *output_p,
          int64_t sizeD,
          int64_t isizeT,
          int64_t isizeH,
          int64_t isizeW,
          int64_t osizeT,
          int64_t osizeH,
          int64_t osizeW,
          int64_t istrideD,
          int64_t istrideT,
          int64_t istrideH,
          int64_t istrideW)
{
  at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) {
    for (auto d = start; d < end; d++)
    {
      /* loop over output */
      int64_t ot, oh, ow;
      for(ot = 0; ot < osizeT; ot++)
      {
        int istartT = START_IND(ot, osizeT, isizeT);
        int iendT = END_IND(ot, osizeT, isizeT);
        int kT = iendT - istartT;

        for(oh = 0; oh < osizeH; oh++)
        {
          int istartH = START_IND(oh, osizeH, isizeH);
          int iendH = END_IND(oh, osizeH, isizeH);
          int kH = iendH - istartH;

          for(ow = 0; ow < osizeW; ow++)
          {

            int istartW = START_IND(ow, osizeW, isizeW);
            int iendW = END_IND(ow, osizeW, isizeW);
            int kW = iendW - istartW;

            /* local pointers */
            scalar_t *ip = input_p + d*istrideD + istartT*istrideT + istartH*istrideH + istartW*istrideW;
            scalar_t *op = output_p + d*osizeT*osizeH*osizeW + ot*osizeH*osizeW + oh*osizeW + ow;

            /* compute local average: */
            scalar_t sum = 0;
            int it, ih, iw;
            for(it = 0; it < kT; it++)
            {
              for(ih = 0; ih < kH; ih++)
              {
                for(iw = 0; iw < kW; iw++)
                {
                  scalar_t val = *(ip + it*istrideT + ih*istrideH + iw*istrideW);
                  sum += val;
                }
              }
            }

            /* set output to local average */
            *op = sum / kT / kH / kW;
          }
        }
      }
    }
  });
}

void THNN_(VolumetricAdaptiveAveragePooling_updateOutput)(
          THNNState *state,
          THTensor *input,
          THTensor *output,
          int osizeT,
          int osizeW,
          int osizeH)
{
  int dimD = 0;
  int dimT = 1;
  int dimH = 2;
  int dimW = 3;
  int64_t sizeB = 1;
  int64_t sizeD = 0;
  int64_t isizeT = 0;
  int64_t isizeH = 0;
  int64_t isizeW = 0;

  int64_t istrideB = 0;
  int64_t istrideD = 0;
  int64_t istrideT = 0;
  int64_t istrideH = 0;
  int64_t istrideW = 0;

  scalar_t *input_data = nullptr;
  scalar_t *output_data = nullptr;


  THNN_ARGCHECK(!input->is_empty() && (input->dim() == 4 || input->dim() == 5), 2, input,
                "non-empty 4D or 5D (batch mode) tensor expected for input, but got: %s");

  if (input->dim() == 5)
  {
    istrideB = input->stride(0);
    sizeB = input->size(0);
    dimD++;
    dimT++;
    dimH++;
    dimW++;
  }

  /* sizes */
  sizeD = input->size(dimD);
  isizeT = input->size(dimT);
  isizeH = input->size(dimH);
  isizeW = input->size(dimW);
  /* strides */
  istrideD = input->stride(dimD);
  istrideT = input->stride(dimT);
  istrideH = input->stride(dimH);
  istrideW = input->stride(dimW);

  /* resize output */
  if (input->dim() == 4)
  {
    THTensor_(resize4d)(output, sizeD, osizeT, osizeH, osizeW);

    input_data = input->data<scalar_t>();
    output_data = output->data<scalar_t>();

    THNN_(VolumetricAdaptiveAveragePooling_updateOutput_frame)(input_data, output_data,
                                                      sizeD,
                                                      isizeT, isizeH, isizeW,
                                                      osizeT, osizeH, osizeW,
                                                      istrideD, istrideT,
                                                      istrideH, istrideW);
  }
  else
  {
    THTensor_(resize5d)(output, sizeB, sizeD, osizeT, osizeH, osizeW);

    input_data = input->data<scalar_t>();
    output_data = output->data<scalar_t>();

    at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) {
      for (auto b = start; b < end; b++)
      {
        THNN_(VolumetricAdaptiveAveragePooling_updateOutput_frame)(input_data+b*istrideB, output_data+b*sizeD*osizeT*osizeH*osizeW,
                                                          sizeD,
                                                          isizeT, isizeH, isizeW,
                                                          osizeT, osizeH, osizeW,
                                                          istrideD, istrideT,
                                                          istrideH, istrideW);
      }
    });
  }
}

static void THNN_(VolumetricAdaptiveAveragePooling_updateGradInput_frame)(
          scalar_t *gradInput_p,
          scalar_t *gradOutput_p,
          int64_t sizeD,
          int64_t isizeT,
          int64_t isizeH,
          int64_t isizeW,
          int64_t osizeT,
          int64_t osizeH,
          int64_t osizeW)
{
  at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) {
    for (auto d = start; d < end; d++)
    {
      scalar_t *gradInput_p_d = gradInput_p + d*isizeT*isizeW*isizeH;
      scalar_t *gradOutput_p_d = gradOutput_p + d*osizeT*osizeW*osizeH;

      /* calculate average */
      int64_t ot, oh, ow;
      for(ot = 0; ot < osizeT; ot++)
      {
        int istartT = START_IND(ot, osizeT, isizeT);
        int iendT = END_IND(ot, osizeT, isizeT);
        int kT = iendT - istartT;

        for(oh = 0; oh < osizeH; oh++)
        {
          int istartH = START_IND(oh, osizeH, isizeH);
          int iendH = END_IND(oh, osizeH, isizeH);
          int kH = iendH - istartH;

          for(ow = 0; ow < osizeW; ow++)
          {

            int istartW = START_IND(ow, osizeW, isizeW);
            int iendW = END_IND(ow, osizeW, isizeW);
            int kW = iendW - istartW;

            scalar_t grad_delta = gradOutput_p_d[ot*osizeH*osizeW + oh*osizeW + ow] / kT / kH / kW;

            int it, ih, iw;
            for(it = istartT; it < iendT; it++)
            {
              for(ih = istartH; ih < iendH; ih++)
              {
                for(iw = istartW; iw < iendW; iw++)
                {
                  /* update gradient */
                  gradInput_p_d[it*isizeH*isizeW + ih*isizeW + iw] += grad_delta;
                }
              }
            }
          }
        }
      }
    }
  });
}

void THNN_(VolumetricAdaptiveAveragePooling_updateGradInput)(
          THNNState *state,
          THTensor *input,
          THTensor *gradOutput,
          THTensor *gradInput)
{
  int dimD = 0;
  int dimT = 1;
  int dimH = 2;
  int dimW = 3;
  int64_t sizeB = 1;
  int64_t sizeD;
  int64_t isizeT;
  int64_t isizeH;
  int64_t isizeW;
  int64_t osizeT;
  int64_t osizeH;
  int64_t osizeW;
  scalar_t *gradInput_data;
  scalar_t *gradOutput_data;

  /* get contiguous gradOutput */
  gradOutput = THTensor_(newContiguous)(gradOutput);

  /* resize */
  THTensor_(resizeAs)(gradInput, input);
  THTensor_(zero)(gradInput);

  if (input->dim() == 5) {
    sizeB = input->size(0);
    dimD++;
    dimT++;
    dimH++;
    dimW++;
  }

  /* sizes */
  sizeD = input->size(dimD);
  isizeT = input->size(dimT);
  isizeH = input->size(dimH);
  isizeW = input->size(dimW);
  osizeT = gradOutput->size(dimT);
  osizeH = gradOutput->size(dimH);
  osizeW = gradOutput->size(dimW);

  /* get raw pointers */
  gradInput_data = gradInput->data<scalar_t>();
  gradOutput_data = gradOutput->data<scalar_t>();

  /* backprop */
  if (input->dim() == 4)
  {
    THNN_(VolumetricAdaptiveAveragePooling_updateGradInput_frame)(gradInput_data, gradOutput_data,
                                                         sizeD,
                                                         isizeT, isizeH, isizeW,
                                                         osizeT, osizeH, osizeW);
  }
  else
  {
    at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) {
      for (auto b = start; b < end; b++)
      {
        THNN_(VolumetricAdaptiveAveragePooling_updateGradInput_frame)(gradInput_data+b*sizeD*isizeT*isizeH*isizeW, gradOutput_data+b*sizeD*osizeT*osizeH*osizeW,
                                                             sizeD,
                                                             isizeT, isizeH, isizeW,
                                                             osizeT, osizeH, osizeW);
      }
    });
  }

  /* cleanup */
  c10::raw::intrusive_ptr::decref(gradOutput);
}

#endif

#undef START_IND
#undef END_IND
@@ -178,9 +178,6 @@
#include <THNN/generic/VolumetricDilatedConvolution.c>
#include <TH/THGenerateFloatTypes.h>

#include <THNN/generic/VolumetricAdaptiveAveragePooling.c>
#include <TH/THGenerateFloatTypes.h>

#include <THNN/generic/VolumetricDilatedMaxPooling.c>
#include <TH/THGenerateFloatTypes.h>

@@ -283,7 +283,6 @@ def _generate_function_classes(scope_dict):
        'VolumetricAveragePooling',
        'VolumetricMaxPooling',
        'VolumetricMaxUnpooling',
        'VolumetricAdaptiveAveragePooling',
        'VolumetricConvolution',
        'VolumetricFullConvolution',
        'VolumetricConvolutionMM',