optimize gather performance for gnn usage on CPU (#87586)

On classic pyg user case for message passing, `gather` has `index` tensor in a broadcasted shape, e.g. with shape `5000, 128` and stride `[1, 0]`. That indicated gather is done on each row of the self tensor. The current implementation will try to parallel on the inner dimension which is bad performance for CPU and unable to be vectorized.

This PR addressed this use case and optimize in a similar manner to index_select, parallel on outer dimension of `index` and do vectorized copy on inner dimension.

Performance benchmarking on Xeon Icelake single socket on `GCN`: the `gather` reduced from `150.787ms` to `10.926ms`, after this optimization, `gather` will no longer be the major bottleneck for training of GNN models when `EdgeIndex` is in COO format.

for more details, please refer to https://github.com/pyg-team/pytorch_geometric/issues/4891#issuecomment-1288423705

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87586
Approved by: https://github.com/rusty1s, https://github.com/malfet
This commit is contained in:
mingfeima
2023-01-11 21:16:23 +08:00
committed by PyTorch MergeBot
parent f8026413f5
commit dc6916b341
4 changed files with 83 additions and 2 deletions

View File

@ -305,6 +305,34 @@ class TestScatterGather(TestCase):
helper([50, 8, 7], 100)
helper([50, 3, 4, 5], 100)
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16)
def test_gather_expanded_index(self, device, dtype):
def helper(input_size, idx_size):
input = torch.randn(input_size, device=device).to(dtype=dtype)
input2 = input.clone()
shape = [1] * len(input_size)
shape[0] = idx_size
dim_size = input_size[0]
idx = torch.randint(0, dim_size, shape)
# Test the fast path on gather when index is expanded
expanded_shape = input_size
expanded_shape[0] = idx_size
idx = idx.expand(expanded_shape)
idx2 = idx.contiguous()
out = torch.gather(input, 0, idx)
out2 = torch.gather(input2, 0, idx2)
self.assertEqual(out, out2)
helper([50, 17], 100)
helper([50, 1], 100)
helper([50, 8, 7], 100)
helper([50, 3, 4, 5], 100)
# Generic Device Test Framework instantation, see
# https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
# for details.