Avoid dynamic dispatch inside the omp loop in AdaptiveAvgPool2d (#20366)

Summary:
This PR changes the CPU implementation of `AdaptiveAveragePool2D` by
- moving the scalar-type dispatch outside the OpenMP loop (see the sketch below)
- adding fp16 support

Pull Request resolved: https://github.com/pytorch/pytorch/pull/20366

Differential Revision: D15456069

Pulled By: ezyang

fbshipit-source-id: 00fa2916f8b136af9f5c8b5db0eca4619f9f5bac
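
The essence of the change, as a minimal standalone sketch (the `batched_copy` op is illustrative only, not part of the PR; `AT_DISPATCH_FLOATING_TYPES_AND_HALF` and `at::parallel_for` are the same ATen utilities used in the diff below): the scalar-type dispatch is hoisted out of the batch loop, so the dispatch machinery runs once per call instead of once per batch element, and the parallel loop body is fully typed.

#include <ATen/ATen.h>
#include <ATen/Parallel.h>

// Hypothetical op restructured the same way as this commit: dispatch once,
// then run a typed, dispatch-free parallel loop over the batch dimension.
void batched_copy(const at::Tensor& input, at::Tensor& output) {
  int64_t sizeB = input.size(0);
  int64_t sizeF = input.numel() / sizeB;  // elements per batch "frame"
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "batched_copy", [&] {
    auto input_data = input.data<scalar_t>();    // typed pointers resolved once
    auto output_data = output.data<scalar_t>();
    at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) {
      for (auto b = start; b < end; b++) {       // typed per-frame work
        for (int64_t i = 0; i < sizeF; i++) {
          output_data[b * sizeF + i] = input_data[b * sizeF + i];
        }
      }
    });
  });
}

The old code nested these the other way around, re-resolving the type dispatch for every batch element inside the OpenMP loop.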
Commit: 48424a6c94
Parent: cf0268e51c
Author: Masaki Kozuki
Date: 2019-05-23 13:21:34 -07:00
Committed by: Facebook Github Bot


@@ -1,5 +1,5 @@
-#include "ATen/ATen.h"
-#include "ATen/NativeFunctions.h"
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Parallel.h>
 #include <tuple>
@@ -18,7 +18,7 @@ namespace {
 }
 
 template <typename scalar_t>
-static void adaptive_avg_pool2d_out_frame(
+static void adaptive_avg_pool2d_single_out_frame(
           scalar_t *input_p,
           scalar_t *output_p,
           int64_t sizeD,
@@ -71,6 +71,36 @@ namespace {
   });
 }
 
+template <typename scalar_t>
+void adaptive_avg_pool2d_out_frame(
+          scalar_t *input_p,
+          scalar_t *output_p,
+          int64_t sizeB,
+          int64_t sizeD,
+          int64_t isizeH,
+          int64_t isizeW,
+          int64_t osizeH,
+          int64_t osizeW,
+          int64_t istrideB,
+          int64_t istrideD,
+          int64_t istrideH,
+          int64_t istrideW)
+{
+  at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) {
+    for (auto b = start; b < end; b++)
+    {
+      adaptive_avg_pool2d_single_out_frame<scalar_t>(
+        input_p + b * istrideB,
+        output_p + b * sizeD * osizeH * osizeW,
+        sizeD,
+        isizeH, isizeW,
+        osizeH, osizeW,
+        istrideD,
+        istrideH, istrideW);
+    }
+  });
+}
+
 void adaptive_avg_pool2d_out_cpu_template(
   at::Tensor& output,
   at::Tensor const& input,
@@ -103,43 +133,45 @@ namespace {
   {
     output.resize_({sizeD, osizeH, osizeW});
 
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] {
         auto input_data = input.data<scalar_t>();
         auto output_data = output.data<scalar_t>();
-        adaptive_avg_pool2d_out_frame<scalar_t>(input_data, output_data,
-                                                sizeD,
-                                                isizeH, isizeW,
-                                                osizeH, osizeW,
-                                                istrideD,
-                                                istrideH, istrideW);
+        adaptive_avg_pool2d_single_out_frame<scalar_t>(
+          input_data,
+          output_data,
+          sizeD,
+          isizeH, isizeW,
+          osizeH, osizeW,
+          istrideD,
+          istrideH, istrideW);
       }
     );
   }
   else
   {
-    output.resize_({input.size(-4), sizeD, osizeH, osizeW});
+    int64_t sizeB = input.size(-4);
+    output.resize_({sizeB, sizeD, osizeH, osizeW});
+    int64_t istrideB = input.stride(-4);
 
-    at::parallel_for(0, input.size(0), 0, [&](int64_t start, int64_t end) {
-      for (auto b = start; b < end; b++)
-      {
-        AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] {
-          auto input_data = input.data<scalar_t>();
-          auto output_data = output.data<scalar_t>();
-          adaptive_avg_pool2d_out_frame<scalar_t>(input_data+b*input.stride(0), output_data+b*sizeD*osizeH*osizeW,
-                                                  sizeD,
-                                                  isizeH, isizeW,
-                                                  osizeH, osizeW,
-                                                  istrideD,
-                                                  istrideH, istrideW);
-        }
-        );
-      }
-    });
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] {
+      auto input_data = input.data<scalar_t>();
+      auto output_data = output.data<scalar_t>();
+      adaptive_avg_pool2d_out_frame<scalar_t>(
+        input_data,
+        output_data,
+        sizeB,
+        sizeD,
+        isizeH, isizeW,
+        osizeH, osizeW,
+        istrideB,
+        istrideD,
+        istrideH, istrideW);
+    });
   }
 }
 
 template <typename scalar_t>
-static void adaptive_avg_pool2d_backward_out_frame(
+static void adaptive_avg_pool2d_backward_single_out_frame(
           scalar_t *gradInput_p,
           scalar_t *gradOutput_p,
           int64_t sizeD,
@@ -186,6 +218,32 @@ namespace {
   });
 }
 
+template <typename scalar_t>
+void adaptive_avg_pool2d_backward_out_frame(
+          scalar_t *gradInput_p,
+          scalar_t *gradOutput_p,
+          int64_t sizeB,
+          int64_t sizeD,
+          int64_t isizeH,
+          int64_t isizeW,
+          int64_t osizeH,
+          int64_t osizeW)
+{
+  at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) {
+    for (auto b = start; b < end; b++)
+    {
+      scalar_t *gradInput_p_d = gradInput_p + b * sizeD * isizeW * isizeH;
+      scalar_t *gradOutput_p_d = gradOutput_p + b * sizeD * osizeW * osizeH;
+      adaptive_avg_pool2d_backward_single_out_frame<scalar_t>(
+        gradInput_p_d,
+        gradOutput_p_d,
+        sizeD,
+        isizeH, isizeW,
+        osizeH, osizeW);
+    }
+  });
+}
+
 Tensor& adaptive_avg_pool2d_backward_out_cpu_template(
   Tensor& gradInput,
   const Tensor& gradOutput_,
@@ -204,13 +262,13 @@ namespace {
   /* backprop */
   if (input.ndimension() == 3)
   {
-    AT_DISPATCH_FLOATING_TYPES(
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] {
         /* get raw pointers */
         scalar_t *gradInput_data = gradInput.data<scalar_t>();
         scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
 
-        adaptive_avg_pool2d_backward_out_frame<scalar_t>(
+        adaptive_avg_pool2d_backward_single_out_frame<scalar_t>(
           gradInput_data, gradOutput_data,
           sizeD,
           isizeH, isizeW,
@@ -220,24 +278,20 @@ namespace {
   }
   else
   {
-    at::parallel_for(0, input.size(0), 0, [&](int64_t start, int64_t end) {
-      for (auto b = start; b < end; b++)
-      {
-        AT_DISPATCH_FLOATING_TYPES(
-          input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] {
-            /* get raw pointers */
-            scalar_t *gradInput_data = gradInput.data<scalar_t>();
-            scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] {
+        /* get raw pointers */
+        scalar_t *gradInput_data = gradInput.data<scalar_t>();
+        scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+        int64_t sizeB = input.size(-4);
 
-            adaptive_avg_pool2d_backward_out_frame<scalar_t>(
-              gradInput_data+b*sizeD*isizeH*isizeW, gradOutput_data+b*sizeD*osizeH*osizeW,
-              sizeD,
-              isizeH, isizeW,
-              osizeH, osizeW);
-          }
-        );
+        adaptive_avg_pool2d_backward_out_frame<scalar_t>(
+          gradInput_data, gradOutput_data,
+          sizeB, sizeD,
+          isizeH, isizeW,
+          osizeH, osizeW);
       }
-    });
+    );
   }
   return gradInput;
 }
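
With the dispatch widened to `AT_DISPATCH_FLOATING_TYPES_AND_HALF`, CPU adaptive average pooling also accepts half-precision tensors. A quick smoke test against a libtorch build containing this change (a sketch, not part of the PR):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // fp16 input on CPU: this reaches a Half kernel only after this change.
  auto x = at::rand({2, 3, 8, 8}).to(at::kHalf);
  auto y = at::adaptive_avg_pool2d(x, {4, 4});
  std::cout << y.sizes() << " " << y.dtype() << std::endl;  // [2, 3, 4, 4] Half
  return 0;
}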