mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
optimize gather performance for gnn usage on CPU (#87586)
On classic pyg user case for message passing, `gather` has `index` tensor in a broadcasted shape, e.g. with shape `5000, 128` and stride `[1, 0]`. That indicated gather is done on each row of the self tensor. The current implementation will try to parallel on the inner dimension which is bad performance for CPU and unable to be vectorized. This PR addressed this use case and optimize in a similar manner to index_select, parallel on outer dimension of `index` and do vectorized copy on inner dimension. Performance benchmarking on Xeon Icelake single socket on `GCN`: the `gather` reduced from `150.787ms` to `10.926ms`, after this optimization, `gather` will no longer be the major bottleneck for training of GNN models when `EdgeIndex` is in COO format. for more details, please refer to https://github.com/pyg-team/pytorch_geometric/issues/4891#issuecomment-1288423705 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87586 Approved by: https://github.com/rusty1s, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
f8026413f5
commit
dc6916b341
@ -494,6 +494,7 @@ DEFINE_DISPATCH(scatter_reduce_two_stub);
|
||||
|
||||
DEFINE_DISPATCH(scatter_add_expanded_index_stub);
|
||||
DEFINE_DISPATCH(scatter_reduce_expanded_index_stub);
|
||||
DEFINE_DISPATCH(gather_expanded_index_stub);
|
||||
|
||||
static bool all_strides_match(TensorList tensors) {
|
||||
TORCH_CHECK(tensors.size() >= 1);
|
||||
@ -1485,7 +1486,11 @@ TORCH_IMPL_FUNC(gather_out)
|
||||
(const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& result) {
|
||||
if (index.numel() == 0) return;
|
||||
dim = at::maybe_wrap_dim(dim, self.dim());
|
||||
gather_stub(result.device().type(), result, self, dim, index);
|
||||
if (can_use_expanded_index_path(result, dim, index, self, /*is_scatter_like=*/false)) {
|
||||
gather_expanded_index_stub(result.device().type(), result, self, index);
|
||||
} else {
|
||||
gather_stub(result.device().type(), result, self, dim, index);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad) {
|
||||
|
@ -52,7 +52,7 @@ static inline bool can_use_expanded_index_path(
|
||||
}
|
||||
|
||||
const auto st = self.scalar_type();
|
||||
if (!(c10::isFloatingType(st) || st == ScalarType::Half)) {
|
||||
if (!(c10::isFloatingType(st)) || st == ScalarType::Half) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -96,8 +96,10 @@ static inline bool can_use_expanded_index_path(
|
||||
|
||||
using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
|
||||
using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const SCATTER_GATHER_OP& reduce, bool);
|
||||
using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);
|
||||
|
||||
DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
|
||||
DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
|
||||
DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);
|
||||
|
||||
}} // namespace at::native
|
||||
|
@ -750,6 +750,44 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index, const Tensor& self) {
|
||||
int64_t* index_data = index.data_ptr<int64_t>();
|
||||
scalar_t* result_data = result.data_ptr<scalar_t>();
|
||||
scalar_t* self_data = self.data_ptr<scalar_t>();
|
||||
|
||||
const int64_t M = ensure_nonempty_size(result, 0);
|
||||
const int64_t N = ensure_nonempty_size(self, 0);
|
||||
const int64_t K = index.numel() / M;
|
||||
|
||||
const int64_t index_upper_bound = N;
|
||||
|
||||
using Vec = vec::Vectorized<scalar_t>;
|
||||
int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / K);
|
||||
at::parallel_for(0, M, grain_size, [&](int64_t begin, int64_t end) {
|
||||
for (const auto m : c10::irange(begin, end)) {
|
||||
scalar_t* result_ptr = result_data + m * K;
|
||||
int64_t index = index_data[m];
|
||||
TORCH_CHECK(index >= 0 && index < index_upper_bound,
|
||||
"index ", index,
|
||||
" is out of bounds for dimension ", 0,
|
||||
" with size ", index_upper_bound);
|
||||
scalar_t* self_ptr = self_data + index * K;
|
||||
int64_t d = 0;
|
||||
for (; d < K - (K % Vec::size()); d += Vec::size()) {
|
||||
Vec out_vec = Vec::loadu(self_ptr + d);
|
||||
out_vec.store(result_ptr + d);
|
||||
}
|
||||
#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
|
||||
# pragma unroll
|
||||
#endif
|
||||
for (; d < K; d++) {
|
||||
result_ptr[d] = self_ptr[d];
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void scatter_add_expanded_index_kernel(const Tensor& self, const Tensor& index, const Tensor& src) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(
|
||||
ScalarType::BFloat16, self.scalar_type(), "scatter_add_expanded_index", [&] {
|
||||
@ -782,6 +820,13 @@ void scatter_reduce_expanded_index_kernel(
|
||||
});
|
||||
}
|
||||
|
||||
void gather_expanded_index_kernel(const Tensor& result, const Tensor& self, const Tensor& index) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(
|
||||
ScalarType::BFloat16, self.scalar_type(), "gather_expanded_index", [&] {
|
||||
cpu_gather_expanded_index_kernel<scalar_t>(result, index, self);
|
||||
});
|
||||
}
|
||||
|
||||
void gather_cpu_kernel(const Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) {
|
||||
cpu_scatter_gather_base_kernel</*is_scatter_like=*/false>()(
|
||||
result, dim, index, self,
|
||||
@ -875,5 +920,6 @@ REGISTER_DISPATCH(scatter_reduce_two_stub, &scatter_reduce_two_cpu_kernel);
|
||||
// fast paths for GNN usage
|
||||
REGISTER_DISPATCH(scatter_add_expanded_index_stub, &scatter_add_expanded_index_kernel);
|
||||
REGISTER_DISPATCH(scatter_reduce_expanded_index_stub, &scatter_reduce_expanded_index_kernel);
|
||||
REGISTER_DISPATCH(gather_expanded_index_stub, &gather_expanded_index_kernel);
|
||||
|
||||
}} // namespace at::native
|
||||
|
@ -305,6 +305,34 @@ class TestScatterGather(TestCase):
|
||||
helper([50, 8, 7], 100)
|
||||
helper([50, 3, 4, 5], 100)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16)
|
||||
def test_gather_expanded_index(self, device, dtype):
|
||||
def helper(input_size, idx_size):
|
||||
input = torch.randn(input_size, device=device).to(dtype=dtype)
|
||||
input2 = input.clone()
|
||||
|
||||
shape = [1] * len(input_size)
|
||||
shape[0] = idx_size
|
||||
dim_size = input_size[0]
|
||||
idx = torch.randint(0, dim_size, shape)
|
||||
|
||||
# Test the fast path on gather when index is expanded
|
||||
expanded_shape = input_size
|
||||
expanded_shape[0] = idx_size
|
||||
idx = idx.expand(expanded_shape)
|
||||
idx2 = idx.contiguous()
|
||||
|
||||
out = torch.gather(input, 0, idx)
|
||||
out2 = torch.gather(input2, 0, idx2)
|
||||
|
||||
self.assertEqual(out, out2)
|
||||
|
||||
helper([50, 17], 100)
|
||||
helper([50, 1], 100)
|
||||
helper([50, 8, 7], 100)
|
||||
helper([50, 3, 4, 5], 100)
|
||||
|
||||
# Generic Device Test Framework instantation, see
|
||||
# https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
|
||||
# for details.
|
||||
|
Reference in New Issue
Block a user