mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
optimization on cpu conv3d (#11884)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11884

In CPU mode, the current convNd uses Im2ColNdNCHWImpl, a generic implementation that handles convolutional layers with an arbitrary number of dimensions. In video modeling we use convNd with 3-dimensional filters. The problem with the current convNd is that Im2ColNdNCHWImpl is much slower than the Im2Col used by conv2d for filters with the same FLOPs. For example, a (1, 7, 7) 3d filter takes 5 times longer than a (7, 7) 2d filter at inference time. This diff extends Im2Col to the 3d case (Im2Col3dNCHWImpl), and the optimization gives 4-5 times faster inference on CPU for various video models: {F128300920}

i-am-not-moving-c2-to-c10

Reviewed By: BIT-silence

Differential Revision: D8245940

fbshipit-source-id: 75231d65c9dd56059dfe31701e26021fd1ff2a85
committed by Facebook Github Bot
parent d714ecf879
commit d843f63f2a
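An editorial aside on the numbers above (not part of the commit message): a (1, 7, 7) 3d filter and a (7, 7) 2d filter both perform 49 multiply-accumulates per output element, so the 5x gap comes from the generic N-d unfolding, not from extra arithmetic. The unfolding reduces the convolution to a single GEMM; consistent with the indexing in the diff below (the symbol names are ours), the shapes are

\mathrm{col} \in \mathbb{R}^{(C\,k_T k_H k_W) \times (T_o H_o W_o)}, \quad
W \in \mathbb{R}^{M \times (C\,k_T k_H k_W)}, \quad
Y = W \cdot \mathrm{col} \in \mathbb{R}^{M \times (T_o H_o W_o)}

where C is the number of input channels, M the number of output filters, k_T, k_H, k_W the kernel size, and T_o, H_o, W_o the output clip length, height, and width.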
@@ -2994,6 +2994,170 @@ C10_EXPORT void Im2ColNdNCHWImpl(
   }
 }
 
+template <typename T>
+void Im2Col3dNCHWImpl(
+    const int channels,
+    const int clip_len,
+    const int height,
+    const int width,
+    const int kernel_t,
+    const int kernel_h,
+    const int kernel_w,
+    const int dilation_t,
+    const int dilation_h,
+    const int dilation_w,
+    const int pad_p,
+    const int pad_t,
+    const int pad_l,
+    const int pad_a,
+    const int pad_b,
+    const int pad_r,
+    const int stride_t,
+    const int stride_h,
+    const int stride_w,
+    const T* img_data,
+    T* col_data) {
+  const int output_t =
+      (clip_len + pad_p + pad_a - (dilation_t * (kernel_t - 1) + 1)) /
+          stride_t +
+      1;
+  const int output_h =
+      (height + pad_b + pad_t - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
+      1;
+  const int output_w =
+      (width + pad_l + pad_r - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
+      1;
+  const int kernel_size = kernel_t * kernel_h * kernel_w;
+  const int kernel_hw_size = kernel_h * kernel_w;
+  const int output_size = output_t * output_h * output_w;
+  const int channel_size = clip_len * height * width;
+  const int output_hw_size = output_h * output_w;
+  const int channel_hw_size = height * width;
+
+  // Fast path for zero padding and no dilation
+  // From Torch, THNN_(unfolded_copy)
+  if (dilation_t == 1 && dilation_h == 1 && dilation_w == 1 && pad_a == 0 &&
+      pad_p == 0 && pad_l == 0 && pad_r == 0 && pad_t == 0 && pad_b == 0) {
+    for (auto k = 0; k < channels * kernel_size; k++) {
+      const auto nip = k / kernel_size;
+      const auto rest = k % kernel_size;
+      const auto kt = rest / kernel_hw_size;
+      const auto rest_hw = rest % kernel_hw_size;
+      const auto kh = rest_hw / kernel_w;
+      const auto kw = rest_hw % kernel_w;
+      auto* dst = col_data + nip * (kernel_size * output_size) +
+          kt * (kernel_hw_size * output_size) + kh * (kernel_w * output_size) +
+          kw * output_size;
+      const auto* src = img_data + nip * channel_size;
+      for (auto t = 0; t < output_t; t++) {
+        const auto it = t * stride_t + kt;
+        for (auto y = 0; y < output_h; y++) {
+          const auto iy = y * stride_h + kh;
+          const auto ix = kw;
+          if (stride_w == 1) {
+            memcpy(
+                dst + (t * output_hw_size + y * output_w),
+                src + (it * channel_hw_size + iy * width + ix),
+                sizeof(T) * output_w);
+          } else {
+            for (auto x = 0; x < output_w; x++) {
+              memcpy(
+                  dst + (t * output_hw_size + y * output_w + x),
+                  src + (it * channel_hw_size + iy * width + ix + x * stride_w),
+                  sizeof(T));
+            }
+          }
+        }
+      }
+    }
+    return;
+  }
+  // Fast path for equal padding
+  if (pad_a == pad_p && pad_l == pad_r && pad_t == pad_b) {
+    const int pad_f = pad_a;
+    const int pad_h = pad_t;
+    const int pad_w = pad_l;
+    for (int channel = channels; channel--; img_data += channel_size) {
+      for (int kernel_frame = 0; kernel_frame < kernel_t; kernel_frame++) {
+        for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+          for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+            int input_frame = -pad_f + kernel_frame * dilation_t;
+            for (int output_frames = output_t; output_frames; output_frames--) {
+              if (!utils::IsAGeZeroAndALtB(input_frame, clip_len)) {
+                for (int output_rows = output_h; output_rows; output_rows--) {
+                  for (int output_cols = output_w; output_cols; output_cols--) {
+                    *(col_data++) = 0;
+                  }
+                }
+              } else {
+                int input_row = -pad_h + kernel_row * dilation_h;
+                for (int output_rows = output_h; output_rows; output_rows--) {
+                  if (!utils::IsAGeZeroAndALtB(input_row, height)) {
+                    for (int output_cols = output_w; output_cols;
+                         output_cols--) {
+                      *(col_data++) = 0;
+                    }
+                  } else {
+                    int input_col = -pad_w + kernel_col * dilation_w;
+                    for (int output_col = output_w; output_col; output_col--) {
+                      if (utils::IsAGeZeroAndALtB(input_col, width)) {
+                        *(col_data++) = img_data
+                            [(input_frame * height + input_row) * width +
+                             input_col];
+                      } else {
+                        *(col_data++) = 0;
+                      }
+                      input_col += stride_w;
+                    }
+                  }
+                  input_row += stride_h;
+                }
+              }
+              input_frame += stride_t;
+            }
+          }
+        }
+      }
+    }
+    return;
+  }
+
+  // Baseline
+  const int dkernel_t = dilation_t * (kernel_t - 1) + 1;
+  const int dkernel_h = dilation_h * (kernel_h - 1) + 1;
+  const int dkernel_w = dilation_w * (kernel_w - 1) + 1;
+
+  int clip_col = (clip_len + pad_p + pad_a - dkernel_t) / stride_t + 1;
+  int height_col = (height + pad_t + pad_b - dkernel_h) / stride_h + 1;
+  int width_col = (width + pad_l + pad_r - dkernel_w) / stride_w + 1;
+
+  int channels_col = channels * kernel_t * kernel_h * kernel_w;
+  for (int c = 0; c < channels_col; ++c) {
+    int w_offset = c % kernel_w;
+    int h_offset = (c / kernel_w) % kernel_h;
+    int t_offset = (c / kernel_w / kernel_h) % kernel_t;
+    int c_im = c / kernel_h / kernel_w / kernel_t;
+    for (int t = 0; t < clip_col; ++t) {
+      for (int h = 0; h < height_col; ++h) {
+        for (int w = 0; w < width_col; ++w) {
+          int t_pad = t * stride_t - pad_p + t_offset * dilation_t;
+          int h_pad = h * stride_h - pad_t + h_offset * dilation_h;
+          int w_pad = w * stride_w - pad_l + w_offset * dilation_w;
+          if (t_pad >= 0 && t_pad < clip_len && h_pad >= 0 && h_pad < height &&
+              w_pad >= 0 && w_pad < width) {
+            col_data[((c * clip_col + t) * height_col + h) * width_col + w] =
+                img_data
+                    [((c_im * clip_len + t_pad) * height + h_pad) * width +
+                     w_pad];
+          } else {
+            col_data[((c * clip_col + t) * height_col + h) * width_col + w] = 0;
+          }
+        }
+      }
+    }
+  }
+}
+
 } // namespace
 
 template <>
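A minimal, hypothetical sanity check of the output-size arithmetic used by Im2Col3dNCHWImpl above; the concrete sizes (a 3-channel, 8-frame 56x56 clip and a 3x3x3 kernel with stride 1, dilation 1 and padding 1 on every side) are our own choices, only the formulas come from the diff:

#include <cassert>

int main() {
  // Hypothetical sizes, not from the diff.
  const int channels = 3, clip_len = 8, height = 56, width = 56;
  const int kernel_t = 3, kernel_h = 3, kernel_w = 3;
  const int pad = 1, stride = 1, dilation = 1;

  // Same formulas as output_t / output_h / output_w in Im2Col3dNCHWImpl.
  const int output_t =
      (clip_len + 2 * pad - (dilation * (kernel_t - 1) + 1)) / stride + 1;
  const int output_h =
      (height + 2 * pad - (dilation * (kernel_h - 1) + 1)) / stride + 1;
  const int output_w =
      (width + 2 * pad - (dilation * (kernel_w - 1) + 1)) / stride + 1;
  // "Same" padding with stride 1 keeps the clip/height/width sizes.
  assert(output_t == 8 && output_h == 56 && output_w == 56);

  // The column buffer written above holds channels * kT * kH * kW rows of
  // output_t * output_h * output_w elements each.
  const long col_elems = static_cast<long>(channels) * kernel_t * kernel_h *
      kernel_w * output_t * output_h * output_w;
  assert(col_elems == 3L * 27 * 8 * 56 * 56);
  return 0;
}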
@@ -3010,18 +3174,45 @@ C10_EXPORT void Im2ColNd<float, CPUContext, StorageOrder::NCHW>(
     const float* img_data,
     float* col_data,
     CPUContext* /* context */) {
-  Im2ColNdNCHWImpl<float, false>(
-      N,
-      img_size,
-      col_size,
-      img_shape,
-      col_shape,
-      kernel_shape,
-      stride,
-      dilation,
-      pad,
-      img_data,
-      col_data);
+  if (N == 3) {
+    const int channels =
+        col_shape[0] / kernel_shape[0] / kernel_shape[1] / kernel_shape[2];
+    Im2Col3dNCHWImpl<float>(
+        channels,
+        img_shape[1],
+        img_shape[2],
+        img_shape[3],
+        kernel_shape[0],
+        kernel_shape[1],
+        kernel_shape[2],
+        dilation[0],
+        dilation[1],
+        dilation[2],
+        pad[0],
+        pad[1],
+        pad[2],
+        pad[3],
+        pad[4],
+        pad[5],
+        stride[0],
+        stride[1],
+        stride[2],
+        img_data,
+        col_data);
+  } else {
+    Im2ColNdNCHWImpl<float, false>(
+        N,
+        img_size,
+        col_size,
+        img_shape,
+        col_shape,
+        kernel_shape,
+        stride,
+        dilation,
+        pad,
+        img_data,
+        col_data);
+  }
 }
 
 template <>
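Finally, a hedged sketch of how a caller's arrays would map onto the N == 3 path dispatched above. Only the pad ordering consumed by Im2Col3dNCHWImpl and the col_shape[0] = channels * kT * kH * kW relation are visible in the diff; the assumption that img_shape[0] holds the channel count and that the remaining col_shape entries are the output dimensions is ours:

// Hypothetical layout for N == 3 (values match the sanity check above).
const int N = 3;
const int img_shape[4] = {3, 8, 56, 56};   // assumed {channels, clip_len, H, W}
const int kernel_shape[3] = {3, 3, 3};     // {kernel_t, kernel_h, kernel_w}
const int dilation[3] = {1, 1, 1};
const int stride[3] = {1, 1, 1};
// Order consumed by Im2Col3dNCHWImpl: pad_p, pad_t, pad_l, pad_a, pad_b, pad_r
// (temporal-front, top, left, temporal-back, bottom, right).
const int pad[6] = {1, 1, 1, 1, 1, 1};
// col_shape[0] must be channels * kT * kH * kW so the dispatch can recover
// channels; the remaining entries are assumed to be the output clip/height/width.
const int col_shape[4] = {3 * 3 * 3 * 3, 8, 56, 56};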