Fix torch.cat() performance regression on single core CPU (#33534)

Summary:
This PR addresses a performance regression in `torch.cat()` on CPU when running with a single thread.
The previous optimization (https://github.com/pytorch/pytorch/issues/30806) introduced a regression in several cases of the PyTorch operator benchmark.
See https://github.com/pytorch/pytorch/issues/33334 for details.
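
For context, a minimal single-thread timing sketch (the shapes and iteration count are illustrative only, not the exact operator-benchmark cases referenced above):

```python
import timeit
import torch

# Force the single-thread path that this PR optimizes; shapes are arbitrary examples.
torch.set_num_threads(1)
x = torch.randn(64, 256)
y = torch.randn(64, 256)

elapsed = timeit.timeit(lambda: torch.cat((x, y), dim=0), number=10000)
print(f"torch.cat single-thread: {elapsed * 1e6 / 10000:.2f} us per call")
```
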
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33534

Differential Revision: D20129963

Pulled By: VitalyFedyunin

fbshipit-source-id: 3fa6cd266978e5b54fa37105555502b77352df3e
Author: Mingfei Ma
Date: 2020-02-28 11:18:51 -08:00
Committed by: Facebook GitHub Bot
Parent: 890242254b
Commit: c6d301220a
4 changed files with 113 additions and 0 deletions


@@ -14,12 +14,15 @@
#include <vector>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/CatKernel.h>
#include <ATen/native/Copy.h>
#include <ATen/MemoryOverlap.h>
namespace at {
namespace native {
DEFINE_DISPATCH(cat_serial_stub);
Tensor _reshape_from_tensor(const Tensor& self, const Tensor& shape_tensor) {
TORCH_CHECK(shape_tensor.dim() == 1);
std::vector<int64_t> shape;
@@ -83,6 +86,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
// size (i.e. other empty sizes are not skipped).
// FIXME: warn if this is the case
bool allSkipped = true;
bool allContiguous = true;
Tensor notSkippedTensor;
// Inputs cannot alias the output tensor
@@ -119,11 +123,17 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
int64_t cat_dim_size = 0;
for (auto const &tensor : tensors) {
if (should_skip(tensor)) {
// don't use fast path for empty tensor
allContiguous = false;
continue;
}
check_cat_shape_except_dim(notSkippedTensor, tensor, dim);
cat_dim_size += tensor.size(dim);
if (!tensor.is_contiguous()) {
allContiguous = false;
}
if (tensor.sizes() != notSkippedTensor.sizes() ||
tensor.strides() != notSkippedTensor.strides()) {
reuse_iterator = false;
@@ -135,6 +145,15 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
result_size[dim] = cat_dim_size;
result.resize_(result_size);
// fast path for the serial (single-thread) case when all inputs and the result are contiguous and non-empty
bool use_serial_kernel = result.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
allContiguous = allContiguous && result.is_contiguous();
ScalarType dtype = notSkippedTensor.scalar_type();
if (use_serial_kernel && allContiguous && (dtype == ScalarType::Double || dtype == ScalarType::Float)) {
cat_serial_stub(kCPU, result, tensors, dim);
return result;
}
int64_t offset = 0;
if (reuse_iterator && result.is_contiguous()) {
auto source_slice = notSkippedTensor;
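
A rough Python restatement of the gating condition above, for readers skimming the diff (the helper name and the hard-coded grain size are illustrative, not part of this PR):

```python
import torch

def would_use_serial_cat(tensors, result):
    # at::internal::GRAIN_SIZE (32768 at the time of writing) is hard-coded here
    # purely for illustration.
    GRAIN_SIZE = 32768
    serial = result.numel() < GRAIN_SIZE or torch.get_num_threads() == 1
    all_contiguous = all(t.is_contiguous() for t in tensors) and result.is_contiguous()
    supported_dtype = result.dtype in (torch.float32, torch.float64)
    # Note: in the actual C++ code an empty (skipped) input also disables the fast path.
    return serial and all_contiguous and supported_dtype

x, y = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
print(would_use_serial_cat((x, y), torch.cat((x, y))))  # True: small, contiguous, float32
```
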


@@ -0,0 +1,75 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cpu/CatKernel.h>
#include <ATen/cpu/vec256/functional.h>
#include <ATen/cpu/vec256/vec256.h>
namespace at { namespace native {
namespace {
struct InputMeta {
void* data_ptr;
int64_t inner_size;
InputMeta(const Tensor& t, int64_t dim, int64_t inner)
: data_ptr(t.data_ptr())
, inner_size(t.size(dim) * inner) {}
};
template <typename scalar_t>
void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
auto size = result.sizes().vec();
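// outer = product of sizes before `dim`, inner = product of sizes after `dim`;
// each input then contributes a contiguous block of size(dim) * inner elements
// per outer index (see InputMeta::inner_size above).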
int64_t outer = 1, inner = 1;
for (int64_t i = 0; i < dim; i++) {
outer *= size[i];
}
for (int64_t i = dim + 1; i < size.size(); i++) {
inner *= size[i];
}
scalar_t* result_data = result.data_ptr<scalar_t>();
int64_t ninputs = tensors.size();
std::vector<InputMeta> inputs;
inputs.reserve(ninputs);
for (auto const &tensor : tensors) {
inputs.emplace_back(tensor, dim, inner);
}
using Vec = vec256::Vec256<scalar_t>;
int64_t offset = 0;
for (int64_t i = 0; i < outer; i++) {
for (int64_t j = 0; j < ninputs; j++) {
scalar_t* result_ptr = result_data + offset;
int64_t local_inner = inputs[j].inner_size;
scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;
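// Copy element-by-element when the block is shorter than one SIMD vector;
// otherwise copy through Vec256 via vec256::map (an identity map).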
if (local_inner < Vec::size()) {
#ifndef _MSC_VER
# pragma unroll
#endif
for (int64_t k = 0; k < local_inner; k++) {
result_ptr[k] = input_ptr[k];
}
} else {
vec256::map(
[](Vec x) { return x; },
result_ptr,
input_ptr,
local_inner);
}
offset += local_inner;
}
}
}
void cat_serial_kernel(Tensor& result, TensorList tensors, int64_t dim) {
AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cat_serial_kernel", [&]() {
cat_serial_kernel_impl<scalar_t>(result, tensors, dim);
});
}
} // anonymous namespace
REGISTER_DISPATCH(cat_serial_stub, &cat_serial_kernel);
}} // at::native
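
The kernel walks the output as `outer` consecutive groups, and within each group copies one contiguous block of `size(dim) * inner` elements per input. A hedged Python reference model of the same traversal (for understanding only, not part of the PR):

```python
import math
import torch

def serial_cat_reference(tensors, dim):
    # Assumes all inputs are contiguous and agree on every size except `dim`,
    # mirroring the preconditions of the fast path.
    sizes = list(tensors[0].shape)
    sizes[dim] = sum(t.size(dim) for t in tensors)
    outer = math.prod(sizes[:dim])        # dims before `dim`
    inner = math.prod(sizes[dim + 1:])    # dims after `dim`
    out = torch.empty(sizes, dtype=tensors[0].dtype)
    flat_out, offset = out.view(-1), 0
    for i in range(outer):
        for t in tensors:
            block = t.size(dim) * inner   # InputMeta::inner_size
            flat_out[offset:offset + block] = t.view(-1)[i * block:(i + 1) * block]
            offset += block
    return out

a, b = torch.randn(2, 3, 4), torch.randn(2, 5, 4)
assert torch.equal(serial_cat_reference((a, b), dim=1), torch.cat((a, b), dim=1))
```
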


@@ -0,0 +1,11 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>
namespace at { namespace native {
using cat_serial_fn = void(*)(Tensor &, TensorList, int64_t);
DECLARE_DISPATCH(cat_serial_fn, cat_serial_stub);
}} // namespace at::native


@@ -6492,6 +6492,14 @@ class TestTorchDeviceType(TestCase):
self.assertEqual(a, b)
self.assertEqual(w[:6], y.view(-1)[:6])
def test_cat_out_channels_last(self, device):
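# cat into a channels_last (non-contiguous) `out` tensor should bypass the new
# contiguous fast path and still match the default contiguous result.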
x = torch.randn((4, 3, 8, 8))
y = torch.randn(x.shape)
res1 = torch.cat((x, y))
z = res1.clone().contiguous(memory_format=torch.channels_last)
res2 = torch.cat((x, y), out=z)
self.assertEqual(res1, res2)
def test_is_set_to(self, device):
t1 = torch.empty(3, 4, 9, 10, device=device)
t2 = torch.empty(3, 4, 9, 10, device=device)