From 48424a6c94143e455cbd1116a7c7726ac5cf559e Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 23 May 2019 13:21:34 -0700 Subject: [PATCH] Avoid dynamic dispatch inside the omp loop in AdaptiveAvgPool2d (#20366) Summary: This PR changes the CPU implementation of `AdaptiveAveragePool2D` by - moving dispatch outside the OpenMP loop - supporting fp16 Pull Request resolved: https://github.com/pytorch/pytorch/pull/20366 Differential Revision: D15456069 Pulled By: ezyang fbshipit-source-id: 00fa2916f8b136af9f5c8b5db0eca4619f9f5bac --- .../ATen/native/AdaptiveAveragePooling.cpp | 144 ++++++++++++------ 1 file changed, 99 insertions(+), 45 deletions(-) diff --git a/aten/src/ATen/native/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/AdaptiveAveragePooling.cpp index e2badf7d6f05..eec9c755bba0 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -1,5 +1,5 @@ -#include "ATen/ATen.h" -#include "ATen/NativeFunctions.h" +#include +#include #include #include @@ -18,7 +18,7 @@ namespace { } template - static void adaptive_avg_pool2d_out_frame( + static void adaptive_avg_pool2d_single_out_frame( scalar_t *input_p, scalar_t *output_p, int64_t sizeD, @@ -71,6 +71,36 @@ namespace { }); } + template + void adaptive_avg_pool2d_out_frame( + scalar_t *input_p, + scalar_t *output_p, + int64_t sizeB, + int64_t sizeD, + int64_t isizeH, + int64_t isizeW, + int64_t osizeH, + int64_t osizeW, + int64_t istrideB, + int64_t istrideD, + int64_t istrideH, + int64_t istrideW) + { + at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) { + for (auto b = start; b < end; b++) + { + adaptive_avg_pool2d_single_out_frame( + input_p + b * istrideB, + output_p + b * sizeD * osizeH * osizeW, + sizeD, + isizeH, isizeW, + osizeH, osizeW, + istrideD, + istrideH, istrideW); + } + }); + } + void adaptive_avg_pool2d_out_cpu_template( at::Tensor& output, at::Tensor const& input, @@ -103,43 +133,45 @@ namespace { { output.resize_({sizeD, 
osizeH, osizeW}); - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { auto input_data = input.data(); auto output_data = output.data(); - adaptive_avg_pool2d_out_frame(input_data, output_data, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideD, - istrideH, istrideW); + adaptive_avg_pool2d_single_out_frame( + input_data, + output_data, + sizeD, + isizeH, isizeW, + osizeH, osizeW, + istrideD, + istrideH, istrideW); } ); } else { - output.resize_({input.size(-4), sizeD, osizeH, osizeW}); + int64_t sizeB = input.size(-4); + output.resize_({sizeB, sizeD, osizeH, osizeW}); + int64_t istrideB = input.stride(-4); - at::parallel_for(0, input.size(0), 0, [&](int64_t start, int64_t end) { - for (auto b = start; b < end; b++) - { - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { - auto input_data = input.data(); - auto output_data = output.data(); - adaptive_avg_pool2d_out_frame(input_data+b*input.stride(0), output_data+b*sizeD*osizeH*osizeW, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideD, - istrideH, istrideW); - } - ); - } + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { + auto input_data = input.data(); + auto output_data = output.data(); + adaptive_avg_pool2d_out_frame( + input_data, + output_data, + sizeB, + sizeD, + isizeH, isizeW, + osizeH, osizeW, + istrideB, + istrideD, + istrideH, istrideW); }); } } template - static void adaptive_avg_pool2d_backward_out_frame( + static void adaptive_avg_pool2d_backward_single_out_frame( scalar_t *gradInput_p, scalar_t *gradOutput_p, int64_t sizeD, @@ -186,6 +218,32 @@ namespace { }); } + template + void adaptive_avg_pool2d_backward_out_frame( + scalar_t *gradInput_p, + scalar_t *gradOutput_p, + int64_t sizeB, + int64_t sizeD, + int64_t isizeH, + int64_t isizeW, + int64_t osizeH, + int64_t osizeW) + { + at::parallel_for(0, 
sizeB, 0, [&](int64_t start, int64_t end) { + for (auto b = start; b < end; b++) + { + scalar_t *gradInput_p_d = gradInput_p + b * sizeD * isizeW * isizeH; + scalar_t *gradOutput_p_d = gradOutput_p + b * sizeD * osizeW * osizeH; + adaptive_avg_pool2d_backward_single_out_frame( + gradInput_p_d, + gradOutput_p_d, + sizeD, + isizeH, isizeW, + osizeH, osizeW); + } + }); + } + Tensor& adaptive_avg_pool2d_backward_out_cpu_template( Tensor& gradInput, const Tensor& gradOutput_, @@ -204,13 +262,13 @@ namespace { /* backprop */ if (input.ndimension() == 3) { - AT_DISPATCH_FLOATING_TYPES( + AT_DISPATCH_FLOATING_TYPES_AND_HALF( input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] { /* get raw pointers */ scalar_t *gradInput_data = gradInput.data(); scalar_t *gradOutput_data = gradOutput.data(); - adaptive_avg_pool2d_backward_out_frame( + adaptive_avg_pool2d_backward_single_out_frame( gradInput_data, gradOutput_data, sizeD, isizeH, isizeW, @@ -220,24 +278,20 @@ namespace { } else { - at::parallel_for(0, input.size(0), 0, [&](int64_t start, int64_t end) { - for (auto b = start; b < end; b++) - { - AT_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] { - /* get raw pointers */ - scalar_t *gradInput_data = gradInput.data(); - scalar_t *gradOutput_data = gradOutput.data(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] { + /* get raw pointers */ + scalar_t *gradInput_data = gradInput.data(); + scalar_t *gradOutput_data = gradOutput.data(); + int64_t sizeB = input.size(-4); - adaptive_avg_pool2d_backward_out_frame( - gradInput_data+b*sizeD*isizeH*isizeW, gradOutput_data+b*sizeD*osizeH*osizeW, - sizeD, - isizeH, isizeW, - osizeH, osizeW); - } - ); + adaptive_avg_pool2d_backward_out_frame( + gradInput_data, gradOutput_data, + sizeB, sizeD, + isizeH, isizeW, + osizeH, osizeW); } - }); + ); } return gradInput; }