optimize gather performance for gnn usage on CPU (#87586)

In the classic PyG use case for message passing, `gather` receives an `index` tensor in a broadcasted shape, e.g. shape `[5000, 128]` with stride `[1, 0]`. That indicates the gather selects whole rows of the `self` tensor. The current implementation tries to parallelize on the inner dimension, which performs poorly on CPU and cannot be vectorized.
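
For illustration, a minimal sketch of how such a broadcasted `index` arises on the Python side (the feature-matrix size here is made up; only the index shape/stride mirrors the description above):

```python
import torch

# Hypothetical sizes for illustration only; the PR describes an index of
# shape [5000, 128] with stride [1, 0].
x = torch.randn(2708, 128)                    # node features
row = torch.randint(0, x.size(0), (5000, 1))  # one source row per output row

# Broadcasting the per-row index across the feature dimension yields the
# "expanded" index: shape [5000, 128], stride (1, 0).
idx = row.expand(5000, 128)
print(idx.shape, idx.stride())                # torch.Size([5000, 128]) (1, 0)

out = torch.gather(x, 0, idx)                 # copies whole rows of x
```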

This PR addresses this use case and optimizes it in a manner similar to `index_select`: parallelize on the outer dimension of `index` and do a vectorized copy on the inner dimension, as sketched below.
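
A rough Python-level sketch of the reference semantics of that fast path (made-up sizes; the real kernel is the C++ code further down, which runs the outer loop with `at::parallel_for`):

```python
import torch

M, K, N = 5000, 128, 2708          # made-up sizes for illustration
x = torch.randn(N, K)
row = torch.randint(0, N, (M, 1))
idx = row.expand(M, K)             # shape [M, K], stride (1, 0)

# Fast-path semantics: iterate over the M output rows (outer dimension of
# `index`) and copy one contiguous row of length K per iteration.
out_ref = torch.empty(M, K)
for m in range(M):                 # parallelized in the C++ kernel
    out_ref[m] = x[idx[m, 0]]      # contiguous, vectorizable row copy

assert torch.equal(out_ref, torch.gather(x, 0, idx))
```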

Performance benchmarking on a single socket of Xeon Ice Lake with `GCN`: `gather` time drops from `150.787ms` to `10.926ms`. After this optimization, `gather` is no longer the major bottleneck for training GNN models when `EdgeIndex` is in COO format.
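
A micro-benchmark sketch for timing the op in isolation (this is not the exact GCN benchmark behind the numbers above, and the sizes are made up; results depend on the machine and thread count):

```python
import torch
import torch.utils.benchmark as benchmark

# Made-up sizes; only meant to exercise the expanded-index gather path.
N, K, M = 50_000, 128, 100_000
x = torch.randn(N, K)
idx = torch.randint(0, N, (M, 1)).expand(M, K)

t = benchmark.Timer(
    stmt="torch.gather(x, 0, idx)",
    globals={"torch": torch, "x": x, "idx": idx},
    num_threads=torch.get_num_threads(),
)
print(t.blocked_autorange())
```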

For more details, please refer to https://github.com/pyg-team/pytorch_geometric/issues/4891#issuecomment-1288423705.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87586
Approved by: https://github.com/rusty1s, https://github.com/malfet
mingfeima
2023-01-11 21:16:23 +08:00
committed by PyTorch MergeBot
parent f8026413f5
commit dc6916b341
4 changed files with 83 additions and 2 deletions


@@ -494,6 +494,7 @@ DEFINE_DISPATCH(scatter_reduce_two_stub);
DEFINE_DISPATCH(scatter_add_expanded_index_stub);
DEFINE_DISPATCH(scatter_reduce_expanded_index_stub);
DEFINE_DISPATCH(gather_expanded_index_stub);
static bool all_strides_match(TensorList tensors) {
TORCH_CHECK(tensors.size() >= 1);
@@ -1485,7 +1486,11 @@ TORCH_IMPL_FUNC(gather_out)
(const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& result) {
if (index.numel() == 0) return;
dim = at::maybe_wrap_dim(dim, self.dim());
gather_stub(result.device().type(), result, self, dim, index);
if (can_use_expanded_index_path(result, dim, index, self, /*is_scatter_like=*/false)) {
gather_expanded_index_stub(result.device().type(), result, self, index);
} else {
gather_stub(result.device().type(), result, self, dim, index);
}
}
Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad) {


@@ -52,7 +52,7 @@ static inline bool can_use_expanded_index_path(
}
const auto st = self.scalar_type();
if (!(c10::isFloatingType(st) || st == ScalarType::Half)) {
if (!(c10::isFloatingType(st)) || st == ScalarType::Half) {
return false;
}
@@ -96,8 +96,10 @@ static inline bool can_use_expanded_index_path(
using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const SCATTER_GATHER_OP& reduce, bool);
using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);
DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);
}} // namespace at::native


@@ -750,6 +750,44 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,
});
}
template <typename scalar_t>
void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index, const Tensor& self) {
int64_t* index_data = index.data_ptr<int64_t>();
scalar_t* result_data = result.data_ptr<scalar_t>();
scalar_t* self_data = self.data_ptr<scalar_t>();
const int64_t M = ensure_nonempty_size(result, 0);
const int64_t N = ensure_nonempty_size(self, 0);
const int64_t K = index.numel() / M;
const int64_t index_upper_bound = N;
using Vec = vec::Vectorized<scalar_t>;
int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / K);
at::parallel_for(0, M, grain_size, [&](int64_t begin, int64_t end) {
for (const auto m : c10::irange(begin, end)) {
scalar_t* result_ptr = result_data + m * K;
int64_t index = index_data[m];
TORCH_CHECK(index >= 0 && index < index_upper_bound,
"index ", index,
" is out of bounds for dimension ", 0,
" with size ", index_upper_bound);
scalar_t* self_ptr = self_data + index * K;
int64_t d = 0;
for (; d < K - (K % Vec::size()); d += Vec::size()) {
Vec out_vec = Vec::loadu(self_ptr + d);
out_vec.store(result_ptr + d);
}
#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
# pragma unroll
#endif
for (; d < K; d++) {
result_ptr[d] = self_ptr[d];
}
}
});
}
void scatter_add_expanded_index_kernel(const Tensor& self, const Tensor& index, const Tensor& src) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, self.scalar_type(), "scatter_add_expanded_index", [&] {
@@ -782,6 +820,13 @@ void scatter_reduce_expanded_index_kernel(
});
}
void gather_expanded_index_kernel(const Tensor& result, const Tensor& self, const Tensor& index) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, self.scalar_type(), "gather_expanded_index", [&] {
cpu_gather_expanded_index_kernel<scalar_t>(result, index, self);
});
}
void gather_cpu_kernel(const Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) {
cpu_scatter_gather_base_kernel</*is_scatter_like=*/false>()(
result, dim, index, self,
@@ -875,5 +920,6 @@ REGISTER_DISPATCH(scatter_reduce_two_stub, &scatter_reduce_two_cpu_kernel);
// fast paths for GNN usage
REGISTER_DISPATCH(scatter_add_expanded_index_stub, &scatter_add_expanded_index_kernel);
REGISTER_DISPATCH(scatter_reduce_expanded_index_stub, &scatter_reduce_expanded_index_kernel);
REGISTER_DISPATCH(gather_expanded_index_stub, &gather_expanded_index_kernel);
}} // namespace at::native


@@ -305,6 +305,34 @@ class TestScatterGather(TestCase):
helper([50, 8, 7], 100)
helper([50, 3, 4, 5], 100)
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16)
def test_gather_expanded_index(self, device, dtype):
def helper(input_size, idx_size):
input = torch.randn(input_size, device=device).to(dtype=dtype)
input2 = input.clone()
shape = [1] * len(input_size)
shape[0] = idx_size
dim_size = input_size[0]
idx = torch.randint(0, dim_size, shape)
# Test the fast path on gather when index is expanded
expanded_shape = input_size
expanded_shape[0] = idx_size
idx = idx.expand(expanded_shape)
idx2 = idx.contiguous()
out = torch.gather(input, 0, idx)
out2 = torch.gather(input2, 0, idx2)
self.assertEqual(out, out2)
helper([50, 17], 100)
helper([50, 1], 100)
helper([50, 8, 7], 100)
helper([50, 3, 4, 5], 100)
# Generic Device Test Framework instantation, see
# https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
# for details.