improve batch_norm contiguous case's performance (#34530)

Summary:
For the batch_norm inference contiguous case, we can get better performance by manually vectorizing it.
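
At inference time, batch norm reduces to a per-channel affine transform, which is what the kernel folds into `alpha`/`beta` and then vectorizes over the spatial dimension. A minimal PyTorch sketch of that folding (illustration only, not code from this PR; module size and tensor shapes are arbitrary):
```
import torch
import torch.nn as nn

torch.manual_seed(0)

# In eval mode, batch norm is a per-channel affine transform:
#   y[n, c, h, w] = x[n, c, h, w] * alpha[c] + beta[c]
# with alpha[c] = weight[c] / sqrt(running_var[c] + eps)
# and  beta[c]  = bias[c] - running_mean[c] * alpha[c].
m = nn.BatchNorm2d(10).eval()
x = torch.randn(4, 10, 8, 8)

alpha = m.weight / torch.sqrt(m.running_var + m.eps)
beta = m.bias - m.running_mean * alpha
y_folded = x * alpha.view(1, -1, 1, 1) + beta.view(1, -1, 1, 1)

print(torch.allclose(m(x), y_folded, atol=1e-6))  # expected: True
```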
Test script:
```
import torch
import torch.nn as nn
import time

torch.manual_seed(0)

for n in [1, 10, 100]:
    for c in [1, 10, 100]:
        for hw in [1, 10, 200]:
            m = nn.BatchNorm2d(c, affine=False)
            m.eval()
            input = torch.randn(20, c, hw, hw)
            # warm up
            for i in range(200):
                output = m(input)
            fwd_t = 0
            for j in range(1000):
                t1 = time.time()
                output = m(input)
                t2 = time.time()
                fwd_t = fwd_t + (t2 - t1)

            fwd_avg = fwd_t / 1000 * 1000
            print("size = (%d, %d, %d, %d); compute time is %.4f(ms)" % (n, c, hw, hw, fwd_avg))
```

Before:
```
size = (1, 1, 1, 1); compute time is 0.0110(ms)
size = (1, 1, 10, 10); compute time is 0.0123(ms)
size = (1, 1, 200, 200); compute time is 0.8166(ms)
size = (1, 10, 1, 1); compute time is 0.0107(ms)
size = (1, 10, 10, 10); compute time is 0.0257(ms)
size = (1, 10, 200, 200); compute time is 8.7533(ms)
size = (1, 100, 1, 1); compute time is 0.0122(ms)
size = (1, 100, 10, 10); compute time is 0.1619(ms)
size = (1, 100, 200, 200); compute time is 123.5674(ms)
size = (10, 1, 1, 1); compute time is 0.0109(ms)
size = (10, 1, 10, 10); compute time is 0.0123(ms)
size = (10, 1, 200, 200); compute time is 0.5629(ms)
size = (10, 10, 1, 1); compute time is 0.0107(ms)
size = (10, 10, 10, 10); compute time is 0.0253(ms)
size = (10, 10, 200, 200); compute time is 8.7817(ms)
size = (10, 100, 1, 1); compute time is 0.0120(ms)
size = (10, 100, 10, 10); compute time is 0.1655(ms)
size = (10, 100, 200, 200); compute time is 123.2488(ms)
size = (100, 1, 1, 1); compute time is 0.0109(ms)
size = (100, 1, 10, 10); compute time is 0.0123(ms)
size = (100, 1, 200, 200); compute time is 0.5740(ms)
size = (100, 10, 1, 1); compute time is 0.0108(ms)
size = (100, 10, 10, 10); compute time is 0.0257(ms)
size = (100, 10, 200, 200); compute time is 8.7201(ms)
size = (100, 100, 1, 1); compute time is 0.0122(ms)
size = (100, 100, 10, 10); compute time is 0.1628(ms)
size = (100, 100, 200, 200); compute time is 123.1739(ms)
```
After:
```
size = (1, 1, 1, 1); compute time is 0.0105(ms)
size = (1, 1, 10, 10); compute time is 0.0114(ms)
size = (1, 1, 200, 200); compute time is 0.5771(ms)
size = (1, 10, 1, 1); compute time is 0.0105(ms)
size = (1, 10, 10, 10); compute time is 0.0160(ms)
size = (1, 10, 200, 200); compute time is 6.9851(ms)
size = (1, 100, 1, 1); compute time is 0.0122(ms)
size = (1, 100, 10, 10); compute time is 0.0848(ms)
size = (1, 100, 200, 200); compute time is 98.6758(ms)
size = (10, 1, 1, 1); compute time is 0.0105(ms)
size = (10, 1, 10, 10); compute time is 0.0115(ms)
size = (10, 1, 200, 200); compute time is 0.2690(ms)
size = (10, 10, 1, 1); compute time is 0.0105(ms)
size = (10, 10, 10, 10); compute time is 0.0159(ms)
size = (10, 10, 200, 200); compute time is 6.6946(ms)
size = (10, 100, 1, 1); compute time is 0.0123(ms)
size = (10, 100, 10, 10); compute time is 0.0854(ms)
size = (10, 100, 200, 200); compute time is 98.7327(ms)
size = (100, 1, 1, 1); compute time is 0.0107(ms)
size = (100, 1, 10, 10); compute time is 0.0116(ms)
size = (100, 1, 200, 200); compute time is 0.2681(ms)
size = (100, 10, 1, 1); compute time is 0.0104(ms)
size = (100, 10, 10, 10); compute time is 0.0159(ms)
size = (100, 10, 200, 200); compute time is 6.7507(ms)
size = (100, 100, 1, 1); compute time is 0.0124(ms)
size = (100, 100, 10, 10); compute time is 0.0852(ms)
size = (100, 100, 200, 200); compute time is 98.6866(ms)
```
For a real model, ResNeXt101, we can also get a **~20%** performance improvement at large batch size.
Test script:
```
import torch
import torchvision
import time

torch.manual_seed(0)
#torch.set_num_threads(1)

model = torchvision.models.resnext101_32x8d().eval()

for batch_size in [1, 64]:
    input = torch.randn(batch_size, 3, 224, 224)
    # warm up
    with torch.no_grad():
        for i in range(5):
            output = model(input)

        fwd_t = 0
        for i in range(10):
            t1 = time.time()
            output = model(input)
            t2 = time.time()
            fwd_t = fwd_t + (t2 - t1)

        time_fwd_avg = fwd_t / 10 * 1000
        print("Throughput of resnext101 with batch_size = %d is %10.2f (imgs/s)" % (batch_size, batch_size * 1000 / time_fwd_avg))
```
Before:
```
Throughput of resnext101 with batch_size = 1 is       7.89 (imgs/s)
Throughput of resnext101 with batch_size = 64 is      13.02 (imgs/s)

num_threads = 1
Throughput of resnext101 with batch_size = 1 is       2.97 (imgs/s)
Throughput of resnext101 with batch_size = 64 is       2.75 (imgs/s)
```
After:
```
Throughput of resnext101 with batch_size = 1 is       8.95 (imgs/s)
Throughput of resnext101 with batch_size = 64 is      15.52 (imgs/s)

num_threads = 1
Throughput of resnext101 with batch_size = 1 is       3.10 (imgs/s)
Throughput of resnext101 with batch_size = 64 is       2.88 (imgs/s)
```
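
As a quick check of the **~20%** claim, the throughput numbers above with default threading imply roughly a 19% speedup at batch size 64 (values copied from the tables above):
```
# Speedup implied by the imgs/s numbers quoted above (default num_threads).
before = {1: 7.89, 64: 13.02}
after = {1: 8.95, 64: 15.52}
for bs in before:
    print("batch_size = %d: %.1f%% faster" % (bs, (after[bs] / before[bs] - 1) * 100))
# batch_size = 1: 13.4% faster
# batch_size = 64: 19.2% faster
```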
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34530

Differential Revision: D20479560

Pulled By: ngimel

fbshipit-source-id: 2e788ebcd814556116c90553ec61159eeffb3c16
xiaobingsuper
2020-03-17 09:19:20 -07:00
committed by Facebook GitHub Bot
parent a8ca340ad6
commit acbca57d18
3 changed files with 137 additions and 55 deletions


```
@@ -8,6 +8,7 @@
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
+#include <ATen/native/batch_norm.h>
 #include <vector>
@@ -15,6 +16,8 @@ static const int MIOPEN_DIM_MAX = 5;
 namespace at { namespace native {

+DEFINE_DISPATCH(batch_norm_cpu_inference_contiguous_stub);
+
 namespace {
   void check_dims_match_num_input_features(const char* arg_name, int64_t expected, int64_t actual){
     TORCH_CHECK(actual == expected,
@@ -87,59 +90,6 @@ void batch_norm_cpu_inference_collect_linear_and_constant_terms(
   }
 }
-
-/// A fast path for CPU inference when all tensors are contiguous.
-/// This code achieves machine bandwidth peak without AVX support.
-/// If this changes for future architectures, we can move it to the cpu/
-/// directory.
-template<typename scalar_t>
-void batch_norm_cpu_inference_contiguous(Tensor& output, const Tensor& input,
-    const Tensor& weight /* optional */, const Tensor& bias /* optional */,
-    const Tensor& mean, const Tensor& variance, double eps) {
-  int64_t n_batch = input.size(0);
-  int64_t n_channel = input.size(1);
-  int64_t image_size = input.numel() / n_batch / n_channel;
-
-  scalar_t* output_data = output.data_ptr<scalar_t>();
-  const scalar_t* input_data = input.data_ptr<scalar_t>();
-
-  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  scalar_t* alpha_data = alpha.data_ptr<scalar_t>();
-  scalar_t* beta_data = beta.data_ptr<scalar_t>();
-
-  batch_norm_cpu_inference_collect_linear_and_constant_terms<scalar_t>(
-      alpha_data, beta_data, n_channel, weight, bias, mean, variance, eps);
-
-  // Apply the linear terms to the input,
-  // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
-  // No need to use parallel_for as this function is supposed to be
-  // memory-limited.
-  // Keep the loop struture simple to make sure compiler vectorization kicks in.
-  if (image_size != 1) {
-    for (int64_t n = 0; n < n_batch; ++n) {
-      for (int64_t c = 0; c < n_channel; ++c) {
-        for (int64_t i = 0; i < image_size; ++i) {
-          // Keep all the offset calculation within the inner loop for
-          // simplicity. Compilers are very good at hoisting the common part
-          // outside.
-          int64_t offset = n * n_channel * image_size + c * image_size + i;
-          output_data[offset] = input_data[offset] * alpha_data[c] +
-              beta_data[c];
-        }
-      }
-    }
-  } else {
-    // image_size == 1
-    for (int64_t n = 0; n < n_batch; ++n) {
-      for (int64_t c = 0; c < n_channel; ++c) {
-        int64_t offset = n * n_channel + c;
-        output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
-      }
-    }
-  }
-}
-
 /// A fast path for CPU inference when all tensors are channels last contiguous.
 /// This code achieves machine bandwidth peak without AVX support.
 /// If this changes for future architectures, we can move it to the cpu/
@@ -207,8 +157,8 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
       && running_var.is_contiguous()) {
     Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-    batch_norm_cpu_inference_contiguous<scalar_t>(
-      output, input, weight, bias, running_mean, running_var, eps);
+    batch_norm_cpu_inference_contiguous_stub(kCPU, output, input, weight,
+      bias, running_mean, running_var, eps);
     return std::make_tuple(output, save_mean, save_invstd);
   }
```


```
@@ -0,0 +1,18 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&,
    const Tensor&, const Tensor&, const Tensor&, double);

DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_inference_contiguous_stub);

} // namespace native
} // namespace at
```


```
@@ -0,0 +1,114 @@
#include <ATen/native/batch_norm.h>

#include <ATen/ATen.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>

namespace at { namespace native {
namespace {

using namespace vec256;

template<typename scalar_t>
void batch_norm_cpu_inference_collect_linear_and_constant_terms(
    TensorAccessor<scalar_t, 1> alpha, TensorAccessor<scalar_t, 1> beta, int64_t n_channel,
    const Tensor& weight /* optional */, const Tensor& bias /* optional */,
    const Tensor& mean, const Tensor& variance, double eps) {

  const scalar_t* weight_data = weight.defined() ? weight.data_ptr<scalar_t>() : nullptr;
  const scalar_t* bias_data = bias.defined() ? bias.data_ptr<scalar_t>() : nullptr;
  auto mean_data = mean.accessor<scalar_t, 1>();
  auto var_data = variance.accessor<scalar_t, 1>();

  /// Collect the linear and constant terms regarding the input.
  /// output(n, c, h, w)
  ///     = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
  ///         + bias(c)
  ///     = input(n, c, h, w) * inv_var(c) * weight(c)
  ///         - mean(c) * inv_var(c) * weight(c) + bias(c),
  /// where inv_var(c) = 1 / sqrt(var(c) + eps).
  /// So the linear term, alpha(c) = inv_var(c) * weight(c),
  /// the constant term beta(c) = bias(c) - mean(c) * inv_var(c) * weight(c)
  /// Note that this is only a good idea if (input_size >> c), in degenerate
  /// cases where image_size == 1 && batch_size == 1, it is slow.
  for (int64_t c = 0; c < n_channel; c++) {
    scalar_t inv_var = 1 / std::sqrt(var_data[c] + static_cast<scalar_t>(eps));
    scalar_t weight_v = weight_data ? weight_data[c] : 1;
    scalar_t bias_v = bias_data ? bias_data[c] : 0;
    alpha[c] = inv_var * weight_v;
    beta[c] = bias_v - mean_data[c] * alpha[c];
  }
}

/// A fast path for CPU inference when all tensors are contiguous.
template<typename scalar_t>
void batch_norm_cpu_inference_contiguous_impl(Tensor& output,
    const Tensor& input, const Tensor& weight, const Tensor& bias,
    const Tensor& mean, const Tensor& variance, double eps) {

  using Vec = Vec256<scalar_t>;
  int64_t n_batch = input.size(0);
  int64_t n_channel = input.size(1);
  int64_t image_size = input.numel() / n_batch / n_channel;

  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto alpha_data = alpha.accessor<scalar_t, 1>();
  auto beta_data = beta.accessor<scalar_t, 1>();

  batch_norm_cpu_inference_collect_linear_and_constant_terms<scalar_t>(
      alpha_data, beta_data, n_channel, weight, bias, mean, variance, eps);

  scalar_t* output_data = output.data_ptr<scalar_t>();
  const scalar_t* input_data = input.data_ptr<scalar_t>();

  // Apply the linear terms to the input,
  // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
  // No need to use parallel_for as this function is supposed to be
  // memory-limited.
  if (image_size != 1) {
    const int64_t n_offset = n_channel * image_size;
    const int64_t loop_size = image_size - (image_size % Vec::size());
    for (int64_t n = 0; n < n_batch; n++) {
      for (int64_t c = 0; c < n_channel; c++) {
        const Vec alpha_vec(alpha_data[c]);
        const Vec beta_vec(beta_data[c]);
        int64_t offset = n * n_offset + c * image_size;
        int64_t d = 0;
        for (; d < loop_size; d += Vec::size()) {
          Vec data_vec = Vec::loadu(input_data + offset + d);
          Vec output_vec = data_vec * alpha_vec + beta_vec;
          output_vec.store(output_data + offset + d);
        }
        if (image_size - d > 0) {
          Vec data_vec = Vec::loadu(input_data + offset + d, image_size - d);
          Vec output_vec = data_vec * alpha_vec + beta_vec;
          output_vec.store(output_data + offset + d, image_size - d);
        }
      }
    }
  } else {
    // image_size == 1
    for (int64_t n = 0; n < n_batch; ++n) {
      for (int64_t c = 0; c < n_channel; ++c) {
        int64_t offset = n * n_channel + c;
        output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
      }
    }
  }
}

void batch_norm_cpu_inference_contiguous_kernel(Tensor& output, const Tensor& input,
    const Tensor& weight, const Tensor& bias, const Tensor& mean, const Tensor& variance, double eps) {
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_inference_contiguous", [&] {
    batch_norm_cpu_inference_contiguous_impl<scalar_t>(output, input, weight, bias, mean, variance, eps);
  });
}

} // anonymous namespace

REGISTER_DISPATCH(batch_norm_cpu_inference_contiguous_stub, &batch_norm_cpu_inference_contiguous_kernel);

}} // namespace at::native
```