mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix torch.cat() performance regression on single core CPU (#33534)
Summary: This PR addresses the performance regression of `torch.cat()` on CPU when running with a single thread. The previous optimization (https://github.com/pytorch/pytorch/issues/30806) introduced regressions in several cases of the PyTorch operator benchmark. See https://github.com/pytorch/pytorch/issues/33334 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/33534 Differential Revision: D20129963 Pulled By: VitalyFedyunin fbshipit-source-id: 3fa6cd266978e5b54fa37105555502b77352df3e
This commit is contained in:
committed by
Facebook Github Bot
parent
890242254b
commit
c6d301220a
@ -14,12 +14,15 @@
|
||||
#include <vector>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cpu/CatKernel.h>
|
||||
#include <ATen/native/Copy.h>
|
||||
#include <ATen/MemoryOverlap.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
DEFINE_DISPATCH(cat_serial_stub);
|
||||
|
||||
Tensor _reshape_from_tensor(const Tensor& self, const Tensor& shape_tensor) {
|
||||
TORCH_CHECK(shape_tensor.dim() == 1);
|
||||
std::vector<int64_t> shape;
|
||||
@ -83,6 +86,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
|
||||
// size (i.e. other empty sizes are not skipped).
|
||||
// FIXME: warn if this is the case
|
||||
bool allSkipped = true;
|
||||
bool allContiguous = true;
|
||||
Tensor notSkippedTensor;
|
||||
|
||||
// Inputs cannot alias the output tensor
|
||||
@ -119,11 +123,17 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
|
||||
int64_t cat_dim_size = 0;
|
||||
for (auto const &tensor : tensors) {
|
||||
if (should_skip(tensor)) {
|
||||
// don't use fast path for empty tensor
|
||||
allContiguous = false;
|
||||
continue;
|
||||
}
|
||||
check_cat_shape_except_dim(notSkippedTensor, tensor, dim);
|
||||
cat_dim_size += tensor.size(dim);
|
||||
|
||||
if (!tensor.is_contiguous()) {
|
||||
allContiguous = false;
|
||||
}
|
||||
|
||||
if (tensor.sizes() != notSkippedTensor.sizes() ||
|
||||
tensor.strides() != notSkippedTensor.strides()) {
|
||||
reuse_iterator = false;
|
||||
@ -135,6 +145,15 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
|
||||
result_size[dim] = cat_dim_size;
|
||||
result.resize_(result_size);
|
||||
|
||||
// fast path for single thread when both inputs and result are contiguous and not empty
|
||||
bool use_serial_kernel = result.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
|
||||
allContiguous = allContiguous && result.is_contiguous();
|
||||
ScalarType dtype = notSkippedTensor.scalar_type();
|
||||
if (use_serial_kernel && allContiguous && (dtype == ScalarType::Double || dtype == ScalarType::Float)) {
|
||||
cat_serial_stub(kCPU, result, tensors, dim);
|
||||
return result;
|
||||
}
|
||||
|
||||
int64_t offset = 0;
|
||||
if (reuse_iterator && result.is_contiguous()) {
|
||||
auto source_slice = notSkippedTensor;
|
||||
|
75
aten/src/ATen/native/cpu/CatKernel.cpp
Normal file
75
aten/src/ATen/native/cpu/CatKernel.cpp
Normal file
@ -0,0 +1,75 @@
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/cpu/CatKernel.h>
|
||||
#include <ATen/cpu/vec256/functional.h>
|
||||
#include <ATen/cpu/vec256/vec256.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
namespace {
|
||||
|
||||
struct InputMeta {
|
||||
void* data_ptr;
|
||||
int64_t inner_size;
|
||||
|
||||
InputMeta(const Tensor& t, int64_t dim, int64_t inner)
|
||||
: data_ptr(t.data_ptr())
|
||||
, inner_size(t.size(dim) * inner) {}
|
||||
};
|
||||
|
||||
template <typename scalar_t>
|
||||
void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
|
||||
auto size = result.sizes().vec();
|
||||
int64_t outer = 1, inner = 1;
|
||||
for (int64_t i = 0; i < dim; i++) {
|
||||
outer *= size[i];
|
||||
}
|
||||
for (int64_t i = dim + 1; i < size.size(); i++) {
|
||||
inner *= size[i];
|
||||
}
|
||||
scalar_t* result_data = result.data_ptr<scalar_t>();
|
||||
int64_t ninputs = tensors.size();
|
||||
std::vector<InputMeta> inputs;
|
||||
inputs.reserve(ninputs);
|
||||
for (auto const &tensor : tensors) {
|
||||
inputs.emplace_back(tensor, dim, inner);
|
||||
}
|
||||
|
||||
using Vec = vec256::Vec256<scalar_t>;
|
||||
int64_t offset = 0;
|
||||
for (int64_t i = 0; i < outer; i++) {
|
||||
for (int64_t j = 0; j < ninputs; j++) {
|
||||
scalar_t* result_ptr = result_data + offset;
|
||||
int64_t local_inner = inputs[j].inner_size;
|
||||
scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;
|
||||
if (local_inner < Vec::size()) {
|
||||
#ifndef _MSC_VER
|
||||
# pragma unroll
|
||||
#endif
|
||||
for (int64_t k = 0; k < local_inner; k++) {
|
||||
result_ptr[k] = input_ptr[k];
|
||||
}
|
||||
} else {
|
||||
vec256::map(
|
||||
[](Vec x) { return x; },
|
||||
result_ptr,
|
||||
input_ptr,
|
||||
local_inner);
|
||||
}
|
||||
offset += local_inner;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Entry point registered for the serial cat stub: selects the floating-point
// instantiation of cat_serial_kernel_impl from the result's scalar type and
// forwards all arguments unchanged.
void cat_serial_kernel(Tensor& result, TensorList tensors, int64_t dim) {
  AT_DISPATCH_FLOATING_TYPES(
      result.scalar_type(),
      "cat_serial_kernel",
      [&] { cat_serial_kernel_impl<scalar_t>(result, tensors, dim); });
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
REGISTER_DISPATCH(cat_serial_stub, &cat_serial_kernel);
|
||||
|
||||
}} // at::native
|
11
aten/src/ATen/native/cpu/CatKernel.h
Normal file
11
aten/src/ATen/native/cpu/CatKernel.h
Normal file
@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
using cat_serial_fn = void(*)(Tensor &, TensorList, int64_t);
|
||||
DECLARE_DISPATCH(cat_serial_fn, cat_serial_stub);
|
||||
|
||||
}} // namespace at::native
|
@ -6492,6 +6492,14 @@ class TestTorchDeviceType(TestCase):
|
||||
self.assertEqual(a, b)
|
||||
self.assertEqual(w[:6], y.view(-1)[:6])
|
||||
|
||||
def test_cat_out_channels_last(self, device):
|
||||
x = torch.randn((4, 3, 8, 8))
|
||||
y = torch.randn(x.shape)
|
||||
res1 = torch.cat((x, y))
|
||||
z = res1.clone().contiguous(memory_format=torch.channels_last)
|
||||
res2 = torch.cat((x, y), out=z)
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
def test_is_set_to(self, device):
|
||||
t1 = torch.empty(3, 4, 9, 10, device=device)
|
||||
t2 = torch.empty(3, 4, 9, 10, device=device)
|
||||
|
Reference in New Issue
Block a user