improve batch_norm contiguous case's performance (#34530)
Summary: For the batch_norm inference contiguous case, we can get better performance by manually vectorizing it.

Test script:
```
import torch
import torch.nn as nn
import time

torch.manual_seed(0)

for n in [1, 10, 100]:
    for c in [1, 10, 100]:
        for hw in [1, 10, 200]:
            m = nn.BatchNorm2d(c, affine=False)
            m.eval()
            input = torch.randn(20, c, hw, hw)
            # warm up
            for i in range(200):
                output = m(input)
            fwd_t = 0
            for j in range(1000):
                t1 = time.time()
                output = m(input)
                t2 = time.time()
                fwd_t = fwd_t + (t2 - t1)
            fwd_avg = fwd_t / 1000 * 1000
            print("size = (%d, %d, %d, %d); compute time is %.4f(ms)" % (n, c, hw, hw, fwd_avg))
```

Before:
```
size = (1, 1, 1, 1); compute time is 0.0110(ms)
size = (1, 1, 10, 10); compute time is 0.0123(ms)
size = (1, 1, 200, 200); compute time is 0.8166(ms)
size = (1, 10, 1, 1); compute time is 0.0107(ms)
size = (1, 10, 10, 10); compute time is 0.0257(ms)
size = (1, 10, 200, 200); compute time is 8.7533(ms)
size = (1, 100, 1, 1); compute time is 0.0122(ms)
size = (1, 100, 10, 10); compute time is 0.1619(ms)
size = (1, 100, 200, 200); compute time is 123.5674(ms)
size = (10, 1, 1, 1); compute time is 0.0109(ms)
size = (10, 1, 10, 10); compute time is 0.0123(ms)
size = (10, 1, 200, 200); compute time is 0.5629(ms)
size = (10, 10, 1, 1); compute time is 0.0107(ms)
size = (10, 10, 10, 10); compute time is 0.0253(ms)
size = (10, 10, 200, 200); compute time is 8.7817(ms)
size = (10, 100, 1, 1); compute time is 0.0120(ms)
size = (10, 100, 10, 10); compute time is 0.1655(ms)
size = (10, 100, 200, 200); compute time is 123.2488(ms)
size = (100, 1, 1, 1); compute time is 0.0109(ms)
size = (100, 1, 10, 10); compute time is 0.0123(ms)
size = (100, 1, 200, 200); compute time is 0.5740(ms)
size = (100, 10, 1, 1); compute time is 0.0108(ms)
size = (100, 10, 10, 10); compute time is 0.0257(ms)
size = (100, 10, 200, 200); compute time is 8.7201(ms)
size = (100, 100, 1, 1); compute time is 0.0122(ms)
size = (100, 100, 10, 10); compute time is 0.1628(ms)
size = (100, 100, 200, 200); compute time is 123.1739(ms)
```

After:
```
size = (1, 1, 1, 1); compute time is 0.0105(ms)
size = (1, 1, 10, 10); compute time is 0.0114(ms)
size = (1, 1, 200, 200); compute time is 0.5771(ms)
size = (1, 10, 1, 1); compute time is 0.0105(ms)
size = (1, 10, 10, 10); compute time is 0.0160(ms)
size = (1, 10, 200, 200); compute time is 6.9851(ms)
size = (1, 100, 1, 1); compute time is 0.0122(ms)
size = (1, 100, 10, 10); compute time is 0.0848(ms)
size = (1, 100, 200, 200); compute time is 98.6758(ms)
size = (10, 1, 1, 1); compute time is 0.0105(ms)
size = (10, 1, 10, 10); compute time is 0.0115(ms)
size = (10, 1, 200, 200); compute time is 0.2690(ms)
size = (10, 10, 1, 1); compute time is 0.0105(ms)
size = (10, 10, 10, 10); compute time is 0.0159(ms)
size = (10, 10, 200, 200); compute time is 6.6946(ms)
size = (10, 100, 1, 1); compute time is 0.0123(ms)
size = (10, 100, 10, 10); compute time is 0.0854(ms)
size = (10, 100, 200, 200); compute time is 98.7327(ms)
size = (100, 1, 1, 1); compute time is 0.0107(ms)
size = (100, 1, 10, 10); compute time is 0.0116(ms)
size = (100, 1, 200, 200); compute time is 0.2681(ms)
size = (100, 10, 1, 1); compute time is 0.0104(ms)
size = (100, 10, 10, 10); compute time is 0.0159(ms)
size = (100, 10, 200, 200); compute time is 6.7507(ms)
size = (100, 100, 1, 1); compute time is 0.0124(ms)
size = (100, 100, 10, 10); compute time is 0.0852(ms)
size = (100, 100, 200, 200); compute time is 98.6866(ms)
```

For a real model, ResNeXt101, we also get a **~20%** performance improvement at large batch size.

Test script:
```
import torch
import torchvision
import time

torch.manual_seed(0)
#torch.set_num_threads(1)

model = torchvision.models.resnext101_32x8d().eval()

for batch_size in [1, 64]:
    input = torch.randn(batch_size, 3, 224, 224)
    # warm up
    with torch.no_grad():
        for i in range(5):
            output = model(input)

        fwd_t = 0
        for i in range(10):
            t1 = time.time()
            output = model(input)
            t2 = time.time()
            fwd_t = fwd_t + (t2 - t1)
        time_fwd_avg = fwd_t / 10 * 1000
        print("Throughput of resnext101 with batch_size = %d is %10.2f (imgs/s)" % (batch_size, batch_size * 1000 / time_fwd_avg))
```

Before:
```
Throughput of resnext101 with batch_size = 1 is 7.89 (imgs/s)
Throughput of resnext101 with batch_size = 64 is 13.02 (imgs/s)

num_threads = 1
Throughput of resnext101 with batch_size = 1 is 2.97 (imgs/s)
Throughput of resnext101 with batch_size = 64 is 2.75 (imgs/s)
```

After:
```
Throughput of resnext101 with batch_size = 1 is 8.95 (imgs/s)
Throughput of resnext101 with batch_size = 64 is 15.52 (imgs/s)

num_threads = 1
Throughput of resnext101 with batch_size = 1 is 3.10 (imgs/s)
Throughput of resnext101 with batch_size = 64 is 2.88 (imgs/s)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/34530

Differential Revision: D20479560

Pulled By: ngimel

fbshipit-source-id: 2e788ebcd814556116c90553ec61159eeffb3c16
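The fast path does not change numerics: it folds the running statistics and the affine parameters into a per-channel scale (`alpha`) and shift (`beta`) before the inner loop, as described in the new kernel's comments, and then applies a single multiply-add per element. A minimal PyTorch sketch of that folding (illustration only, not code from this PR; module, shapes, and tolerance here are arbitrary):

```
import torch
import torch.nn as nn

torch.manual_seed(0)
c = 10
m = nn.BatchNorm2d(c).eval()          # eval mode: uses running stats
x = torch.randn(4, c, 8, 8)           # contiguous NCHW input

with torch.no_grad():
    inv_std = 1.0 / torch.sqrt(m.running_var + m.eps)
    alpha = m.weight * inv_std                    # per-channel linear term
    beta = m.bias - m.running_mean * alpha        # per-channel constant term
    y_folded = x * alpha.view(1, c, 1, 1) + beta.view(1, c, 1, 1)
    # matches the module's inference output
    assert torch.allclose(y_folded, m(x), atol=1e-6)
```

The new kernel computes exactly this, but with the inner loop over the image explicitly vectorized via Vec256.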
Committed by: Facebook GitHub Bot
Parent: a8ca340ad6
Commit: acbca57d18
```
@@ -8,6 +8,7 @@
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
+#include <ATen/native/batch_norm.h>
 
 #include <vector>
 
@@ -15,6 +16,8 @@ static const int MIOPEN_DIM_MAX = 5;
 
 namespace at { namespace native {
 
+DEFINE_DISPATCH(batch_norm_cpu_inference_contiguous_stub);
+
 namespace {
   void check_dims_match_num_input_features(const char* arg_name, int64_t expected, int64_t actual){
     TORCH_CHECK(actual == expected,
@@ -87,59 +90,6 @@ void batch_norm_cpu_inference_collect_linear_and_constant_terms(
   }
 }
 
-/// A fast path for CPU inference when all tensors are contiguous.
-/// This code achieves machine bandwidth peak without AVX support.
-/// If this changes for future architectures, we can move it to the cpu/
-/// directory.
-template<typename scalar_t>
-void batch_norm_cpu_inference_contiguous(Tensor& output, const Tensor& input,
-  const Tensor& weight /* optional */, const Tensor& bias /* optional */,
-  const Tensor& mean, const Tensor& variance, double eps) {
-
-  int64_t n_batch = input.size(0);
-  int64_t n_channel = input.size(1);
-  int64_t image_size = input.numel() / n_batch / n_channel;
-
-  scalar_t* output_data = output.data_ptr<scalar_t>();
-  const scalar_t* input_data = input.data_ptr<scalar_t>();
-
-  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  scalar_t* alpha_data = alpha.data_ptr<scalar_t>();
-  scalar_t* beta_data = beta.data_ptr<scalar_t>();
-
-  batch_norm_cpu_inference_collect_linear_and_constant_terms<scalar_t>(
-    alpha_data, beta_data, n_channel, weight, bias, mean, variance, eps);
-
-  // Apply the linear terms to the input,
-  // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
-  // No need to use parallel_for as this function is supposed to be
-  // memory-limited.
-  // Keep the loop struture simple to make sure compiler vectorization kicks in.
-  if (image_size != 1) {
-    for (int64_t n = 0; n < n_batch; ++n) {
-      for (int64_t c = 0; c < n_channel; ++c) {
-        for (int64_t i = 0; i < image_size; ++i) {
-          // Keep all the offset calculation within the inner loop for
-          // simplicity. Compilers are very good at hoisting the common part
-          // outside.
-          int64_t offset = n * n_channel * image_size + c * image_size + i;
-          output_data[offset] = input_data[offset] * alpha_data[c] +
-              beta_data[c];
-        }
-      }
-    }
-  } else {
-    // image_size == 1
-    for (int64_t n = 0; n < n_batch; ++n) {
-      for (int64_t c = 0; c < n_channel; ++c) {
-        int64_t offset = n * n_channel + c;
-        output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
-      }
-    }
-  }
-}
-
 /// A fast path for CPU inference when all tensors are channels last contiguous.
 /// This code achieves machine bandwidth peak without AVX support.
 /// If this changes for future architectures, we can move it to the cpu/
@@ -207,8 +157,8 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
       && running_var.is_contiguous()) {
 
     Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-    batch_norm_cpu_inference_contiguous<scalar_t>(
-      output, input, weight, bias, running_mean, running_var, eps);
+    batch_norm_cpu_inference_contiguous_stub(kCPU, output, input, weight,
+      bias, running_mean, running_var, eps);
     return std::make_tuple(output, save_mean, save_invstd);
   }
 
```
aten/src/ATen/native/batch_norm.h (new file, 18 lines)

```
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+
+namespace native {
+
+using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&,
+    const Tensor&, const Tensor&, const Tensor&, double);
+
+DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_inference_contiguous_stub);
+
+} // namespace native
+
+} // namespace at
+
```
aten/src/ATen/native/cpu/batch_norm_kernel.cpp (new file, 114 lines)

```
@@ -0,0 +1,114 @@
+#include <ATen/native/batch_norm.h>
+
+#include <ATen/ATen.h>
+#include <ATen/CPUApplyUtils.h>
+#include <ATen/Dispatch.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cpu/Loops.h>
+
+namespace at { namespace native {
+namespace {
+
+using namespace vec256;
+
+template<typename scalar_t>
+void batch_norm_cpu_inference_collect_linear_and_constant_terms(
+    TensorAccessor<scalar_t, 1> alpha, TensorAccessor<scalar_t, 1> beta, int64_t n_channel,
+    const Tensor& weight /* optional */, const Tensor& bias /* optional */,
+    const Tensor& mean, const Tensor& variance, double eps) {
+
+  const scalar_t* weight_data = weight.defined() ? weight.data_ptr<scalar_t>() : nullptr;
+  const scalar_t* bias_data = bias.defined() ? bias.data_ptr<scalar_t>() : nullptr;
+  auto mean_data = mean.accessor<scalar_t, 1>();
+  auto var_data = variance.accessor<scalar_t, 1>();
+
+  /// Collect the linear and constant terms regarding the input.
+  /// output(n, c, h, w)
+  ///     = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
+  ///         + bias(c)
+  ///     = input(n, c, h, w) * inv_var(c) * weight(c)
+  ///         - mean(c) * inv_var(c) * weight(c) + bias(c),
+  /// where inv_var(c) = 1 / sqrt(var(c) + eps).
+  /// So the linear term, alpha(c) = inv_var(c) * weight(c),
+  /// the constant term beta(c) = bias(c) - mean(c) * inv_var(c) * weight(c)
+  /// Note that this is only a good idea if (input_size >> c), in degenerate
+  /// cases where image_size == 1 && batch_size == 1, it is slow.
+  for (int64_t c = 0; c < n_channel; c++) {
+    scalar_t inv_var = 1 / std::sqrt(var_data[c] + static_cast<scalar_t>(eps));
+    scalar_t weight_v = weight_data ? weight_data[c] : 1;
+    scalar_t bias_v = bias_data ? bias_data[c] : 0;
+    alpha[c] = inv_var * weight_v;
+    beta[c] = bias_v - mean_data[c] * alpha[c];
+  }
+}
+
+/// A fast path for CPU inference when all tensors are contiguous.
+template<typename scalar_t>
+void batch_norm_cpu_inference_contiguous_impl(Tensor& output,
+    const Tensor& input, const Tensor& weight, const Tensor& bias,
+    const Tensor& mean, const Tensor& variance, double eps) {
+
+  using Vec = Vec256<scalar_t>;
+  int64_t n_batch = input.size(0);
+  int64_t n_channel = input.size(1);
+  int64_t image_size = input.numel() / n_batch / n_channel;
+
+  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  auto alpha_data = alpha.accessor<scalar_t, 1>();
+  auto beta_data = beta.accessor<scalar_t, 1>();
+
+  batch_norm_cpu_inference_collect_linear_and_constant_terms<scalar_t>(
+    alpha_data, beta_data, n_channel, weight, bias, mean, variance, eps);
+
+  scalar_t* output_data = output.data_ptr<scalar_t>();
+  const scalar_t* input_data = input.data_ptr<scalar_t>();
+
+  // Apply the linear terms to the input,
+  // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
+  // No need to use parallel_for as this function is supposed to be
+  // memory-limited.
+  if (image_size != 1) {
+    const int64_t n_offset = n_channel * image_size;
+    const int64_t loop_size = image_size - (image_size % Vec::size());
+    for (int64_t n = 0; n < n_batch; n++) {
+      for (int64_t c = 0; c < n_channel; c++) {
+        const Vec alpha_vec(alpha_data[c]);
+        const Vec beta_vec(beta_data[c]);
+        int64_t offset = n * n_offset + c * image_size;
+        int64_t d = 0;
+        for (; d < loop_size; d += Vec::size()) {
+          Vec data_vec = Vec::loadu(input_data + offset + d);
+          Vec output_vec = data_vec * alpha_vec + beta_vec;
+          output_vec.store(output_data + offset + d);
+        }
+        if (image_size - d > 0) {
+          Vec data_vec = Vec::loadu(input_data + offset + d, image_size - d);
+          Vec output_vec = data_vec * alpha_vec + beta_vec;
+          output_vec.store(output_data + offset + d, image_size - d);
+        }
+      }
+    }
+  } else {
+    // image_size == 1
+    for (int64_t n = 0; n < n_batch; ++n) {
+      for (int64_t c = 0; c < n_channel; ++c) {
+        int64_t offset = n * n_channel + c;
+        output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
+      }
+    }
+  }
+}
+
+void batch_norm_cpu_inference_contiguous_kernel(Tensor& output, const Tensor& input,
+    const Tensor& weight, const Tensor& bias, const Tensor& mean, const Tensor& variance, double eps) {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_inference_contiguous", [&] {
+    batch_norm_cpu_inference_contiguous_impl<scalar_t>(output, input, weight, bias, mean, variance, eps);
+  });
+}
+
+}// anonymous namespace
+
+REGISTER_DISPATCH(batch_norm_cpu_inference_contiguous_stub, &batch_norm_cpu_inference_contiguous_kernel);
+
+}} // namespace at::native
```
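Restating the derivation from the comments in `batch_norm_cpu_inference_collect_linear_and_constant_terms` above (no new math, just the same folding written compactly, with weight = $\gamma$, bias = $b$, running mean = $\mu$, running variance = $\sigma^2$):

\[
\mathrm{output}(n,c,h,w)
  = \frac{\mathrm{input}(n,c,h,w) - \mu(c)}{\sqrt{\sigma^2(c) + \epsilon}}\,\gamma(c) + b(c)
  = \alpha(c)\,\mathrm{input}(n,c,h,w) + \beta(c),
\]
\[
\text{where}\quad
\alpha(c) = \frac{\gamma(c)}{\sqrt{\sigma^2(c) + \epsilon}},
\qquad
\beta(c) = b(c) - \mu(c)\,\alpha(c).
\]

The kernel computes $\alpha$ and $\beta$ once per channel and then applies the fused multiply-add over the image with `Vec256`, with a masked `loadu`/`store` for the tail elements when `image_size` is not a multiple of the vector width.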