optimization on cpu conv3d (#11884)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11884

In cpu mode, current convNd uses Im2ColNdNCHWImpl, which is generic implementation to handle convolutional layer for arbitrary number of dimensions. In video modeling, we use convNd for filter dimension=3.

The problem of current convNd is that Im2ColNdNCHWImpl is much slower than Im2Col used by conv2d for the filters with same Flops. For example, a (1, 7, 7) 3d filter takes 5 times longer than a (7, 7) 2d filter at inference time.

This diff extends Im2Col to 3d case (Im2Col3dNCHWImpl), and this optimization for 3d convolution gives 4~5 times faster inference time on cpu for various video models:

{F128300920}

i-am-not-moving-c2-to-c10

Reviewed By: BIT-silence

Differential Revision: D8245940

fbshipit-source-id: 75231d65c9dd56059dfe31701e26021fd1ff2a85
This commit is contained in:
Yufei Wang
2018-11-01 15:07:42 -07:00
committed by Facebook Github Bot
parent d714ecf879
commit d843f63f2a

View File

@ -2994,6 +2994,170 @@ C10_EXPORT void Im2ColNdNCHWImpl(
}
}
template <typename T>
void Im2Col3dNCHWImpl(
const int channels,
const int clip_len,
const int height,
const int width,
const int kernel_t,
const int kernel_h,
const int kernel_w,
const int dilation_t,
const int dilation_h,
const int dilation_w,
const int pad_p,
const int pad_t,
const int pad_l,
const int pad_a,
const int pad_b,
const int pad_r,
const int stride_t,
const int stride_h,
const int stride_w,
const T* img_data,
T* col_data) {
const int output_t =
(clip_len + pad_p + pad_a - (dilation_t * (kernel_t - 1) + 1)) /
stride_t +
1;
const int output_h =
(height + pad_b + pad_t - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
1;
const int output_w =
(width + pad_l + pad_r - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
1;
const int kernel_size = kernel_t * kernel_h * kernel_w;
const int kernel_hw_size = kernel_h * kernel_w;
const int output_size = output_t * output_h * output_w;
const int channel_size = clip_len * height * width;
const int output_hw_size = output_h * output_w;
const int channel_hw_size = height * width;
// Fast path for zero padding and no dilation
// From Torch, THNN_(unfolded_copy)
if (dilation_t == 1 && dilation_h == 1 && dilation_w == 1 && pad_a == 0 &&
pad_p == 0 && pad_l == 0 && pad_r == 0 && pad_t == 0 && pad_b == 0) {
for (auto k = 0; k < channels * kernel_size; k++) {
const auto nip = k / kernel_size;
const auto rest = k % kernel_size;
const auto kt = rest / kernel_hw_size;
const auto rest_hw = rest % kernel_hw_size;
const auto kh = rest_hw / kernel_w;
const auto kw = rest_hw % kernel_w;
auto* dst = col_data + nip * (kernel_size * output_size) +
kt * (kernel_hw_size * output_size) + kh * (kernel_w * output_size) +
kw * output_size;
const auto* src = img_data + nip * channel_size;
for (auto t = 0; t < output_t; t++) {
const auto it = t * stride_t + kt;
for (auto y = 0; y < output_h; y++) {
const auto iy = y * stride_h + kh;
const auto ix = kw;
if (stride_w == 1) {
memcpy(
dst + (t * output_hw_size + y * output_w),
src + (it * channel_hw_size + iy * width + ix),
sizeof(T) * output_w);
} else {
for (auto x = 0; x < output_w; x++) {
memcpy(
dst + (t * output_hw_size + y * output_w + x),
src + (it * channel_hw_size + iy * width + ix + x * stride_w),
sizeof(T));
}
}
}
}
}
return;
}
// Fast path for equal padding
if (pad_a == pad_p && pad_l == pad_r && pad_t == pad_b) {
const int pad_f = pad_a;
const int pad_h = pad_t;
const int pad_w = pad_l;
for (int channel = channels; channel--; img_data += channel_size) {
for (int kernel_frame = 0; kernel_frame < kernel_t; kernel_frame++) {
for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_frame = -pad_f + kernel_frame * dilation_t;
for (int output_frames = output_t; output_frames; output_frames--) {
if (!utils::IsAGeZeroAndALtB(input_frame, clip_len)) {
for (int output_rows = output_h; output_rows; output_rows--) {
for (int output_cols = output_w; output_cols; output_cols--) {
*(col_data++) = 0;
}
}
} else {
int input_row = -pad_h + kernel_row * dilation_h;
for (int output_rows = output_h; output_rows; output_rows--) {
if (!utils::IsAGeZeroAndALtB(input_row, height)) {
for (int output_cols = output_w; output_cols;
output_cols--) {
*(col_data++) = 0;
}
} else {
int input_col = -pad_w + kernel_col * dilation_w;
for (int output_col = output_w; output_col; output_col--) {
if (utils::IsAGeZeroAndALtB(input_col, width)) {
*(col_data++) = img_data
[(input_frame * height + input_row) * width +
input_col];
} else {
*(col_data++) = 0;
}
input_col += stride_w;
}
}
input_row += stride_h;
}
}
input_frame += stride_t;
}
}
}
}
}
return;
}
// Baseline
const int dkernel_t = dilation_t * (kernel_t - 1) + 1;
const int dkernel_h = dilation_h * (kernel_h - 1) + 1;
const int dkernel_w = dilation_w * (kernel_w - 1) + 1;
int clip_col = (clip_len + pad_p + pad_a - dkernel_t) / stride_t + 1;
int height_col = (height + pad_t + pad_b - dkernel_h) / stride_h + 1;
int width_col = (width + pad_l + pad_r - dkernel_w) / stride_w + 1;
int channels_col = channels * kernel_t * kernel_h * kernel_w;
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % kernel_w;
int h_offset = (c / kernel_w) % kernel_h;
int t_offset = (c / kernel_w / kernel_h) % kernel_t;
int c_im = c / kernel_h / kernel_w / kernel_t;
for (int t = 0; t < clip_col; ++t) {
for (int h = 0; h < height_col; ++h) {
for (int w = 0; w < width_col; ++w) {
int t_pad = t * stride_t - pad_p + t_offset * dilation_t;
int h_pad = h * stride_h - pad_t + h_offset * dilation_h;
int w_pad = w * stride_w - pad_l + w_offset * dilation_w;
if (t_pad >= 0 && t_pad < clip_len && h_pad >= 0 && h_pad < height &&
w_pad >= 0 && w_pad < width) {
col_data[((c * clip_col + t) * height_col + h) * width_col + w] =
img_data
[((c_im * clip_len + t_pad) * height + h_pad) * width +
w_pad];
} else {
col_data[((c * clip_col + t) * height_col + h) * width_col + w] = 0;
}
}
}
}
}
}
} // namespace
template <>
@ -3010,18 +3174,45 @@ C10_EXPORT void Im2ColNd<float, CPUContext, StorageOrder::NCHW>(
const float* img_data,
float* col_data,
CPUContext* /* context */) {
Im2ColNdNCHWImpl<float, false>(
N,
img_size,
col_size,
img_shape,
col_shape,
kernel_shape,
stride,
dilation,
pad,
img_data,
col_data);
if (N == 3) {
const int channels =
col_shape[0] / kernel_shape[0] / kernel_shape[1] / kernel_shape[2];
Im2Col3dNCHWImpl<float>(
channels,
img_shape[1],
img_shape[2],
img_shape[3],
kernel_shape[0],
kernel_shape[1],
kernel_shape[2],
dilation[0],
dilation[1],
dilation[2],
pad[0],
pad[1],
pad[2],
pad[3],
pad[4],
pad[5],
stride[0],
stride[1],
stride[2],
img_data,
col_data);
} else {
Im2ColNdNCHWImpl<float, false>(
N,
img_size,
col_size,
img_shape,
col_shape,
kernel_shape,
stride,
dilation,
pad,
img_data,
col_data);
}
}
template <>